From 823fb58a3e8c68d3b1fede3bff5bf31650a58a3a Mon Sep 17 00:00:00 2001 From: Jim Ferenczi Date: Wed, 20 Mar 2024 13:14:09 +0000 Subject: [PATCH] [feature/semantic_text] Refactor inference to run as an action filter (#106357) --------- Co-authored-by: carlosdelest --- .../action/bulk/BulkOperation.java | 118 +-- .../action/bulk/BulkShardRequest.java | 27 + .../BulkShardRequestInferenceProvider.java | 338 --------- .../action/bulk/TransportBulkAction.java | 36 +- .../bulk/TransportSimulateBulkAction.java | 4 +- .../vectors/DenseVectorFieldMapper.java | 4 + .../inference/InferenceServiceRegistry.java | 62 +- .../InferenceServiceRegistryImpl.java | 64 -- .../inference/ModelRegistry.java | 99 --- .../elasticsearch/node/NodeConstruction.java | 15 - .../plugins/InferenceRegistryPlugin.java | 22 - .../action/bulk/BulkOperationTests.java | 670 ------------------ ...ActionIndicesThatCannotBeCreatedTests.java | 8 +- .../bulk/TransportBulkActionIngestTests.java | 8 +- .../action/bulk/TransportBulkActionTests.java | 4 +- .../bulk/TransportBulkActionTookTests.java | 16 +- .../snapshots/SnapshotResiliencyTests.java | 4 +- .../TestSparseInferenceServiceExtension.java | 8 +- ...gistryImplIT.java => ModelRegistryIT.java} | 52 +- .../xpack/inference/InferencePlugin.java | 54 +- .../TransportDeleteInferenceModelAction.java | 2 +- .../TransportGetInferenceModelAction.java | 2 +- .../action/TransportInferenceAction.java | 2 +- .../TransportPutInferenceModelAction.java | 2 +- .../ShardBulkInferenceActionFilter.java | 343 +++++++++ ...r.java => InferenceResultFieldMapper.java} | 68 +- .../mapper/SemanticTextFieldMapper.java | 2 +- .../mapper}/SemanticTextModelSettings.java | 11 +- ...elRegistryImpl.java => ModelRegistry.java} | 82 ++- .../ShardBulkInferenceActionFilterTests.java | 344 +++++++++ ...a => InferenceResultFieldMapperTests.java} | 147 ++-- ...ImplTests.java => ModelRegistryTests.java} | 34 +- .../inference/10_semantic_text_inference.yml | 48 +- .../20_semantic_text_field_mapper.yml | 20 +- 34 files changed, 1102 insertions(+), 1618 deletions(-) delete mode 100644 server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java delete mode 100644 server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistryImpl.java delete mode 100644 server/src/main/java/org/elasticsearch/inference/ModelRegistry.java delete mode 100644 server/src/main/java/org/elasticsearch/plugins/InferenceRegistryPlugin.java delete mode 100644 server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java rename x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/{ModelRegistryImplIT.java => ModelRegistryIT.java} (86%) create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java rename x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/{SemanticTextInferenceResultFieldMapper.java => InferenceResultFieldMapper.java} (84%) rename {server/src/main/java/org/elasticsearch/inference => x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper}/SemanticTextModelSettings.java (92%) rename x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/{ModelRegistryImpl.java => ModelRegistry.java} (86%) create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java rename 
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/{SemanticTextInferenceResultFieldMapperTests.java => InferenceResultFieldMapperTests.java} (79%) rename x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/{ModelRegistryImplTests.java => ModelRegistryTests.java} (92%) diff --git a/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java b/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java index 2b84ec8746cd2..452a9ec90443a 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java @@ -10,7 +10,6 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ElasticsearchParseException; import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.action.ActionListener; @@ -36,8 +35,6 @@ import org.elasticsearch.index.IndexNotFoundException; import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.indices.IndexClosedException; -import org.elasticsearch.inference.InferenceServiceRegistry; -import org.elasticsearch.inference.ModelRegistry; import org.elasticsearch.node.NodeClosedException; import org.elasticsearch.tasks.Task; import org.elasticsearch.threadpool.ThreadPool; @@ -47,7 +44,6 @@ import java.util.List; import java.util.Map; import java.util.concurrent.TimeUnit; -import java.util.function.BiConsumer; import java.util.function.LongSupplier; import static org.elasticsearch.cluster.metadata.IndexNameExpressionResolver.EXCLUDED_DATA_STREAMS_KEY; @@ -73,8 +69,6 @@ final class BulkOperation extends ActionRunnable { private final LongSupplier relativeTimeProvider; private IndexNameExpressionResolver indexNameExpressionResolver; private NodeClient client; - private final InferenceServiceRegistry inferenceServiceRegistry; - private final ModelRegistry modelRegistry; BulkOperation( Task task, @@ -88,8 +82,6 @@ final class BulkOperation extends ActionRunnable { IndexNameExpressionResolver indexNameExpressionResolver, LongSupplier relativeTimeProvider, long startTimeNanos, - ModelRegistry modelRegistry, - InferenceServiceRegistry inferenceServiceRegistry, ActionListener listener ) { super(listener); @@ -105,8 +97,6 @@ final class BulkOperation extends ActionRunnable { this.relativeTimeProvider = relativeTimeProvider; this.indexNameExpressionResolver = indexNameExpressionResolver; this.client = client; - this.inferenceServiceRegistry = inferenceServiceRegistry; - this.modelRegistry = modelRegistry; this.observer = new ClusterStateObserver(clusterService, bulkRequest.timeout(), logger, threadPool.getThreadContext()); } @@ -199,30 +189,7 @@ private void executeBulkRequestsByShard(Map> requ return; } - BulkShardRequestInferenceProvider.getInstance( - inferenceServiceRegistry, - modelRegistry, - clusterState, - requestsByShard.keySet(), - new ActionListener() { - @Override - public void onResponse(BulkShardRequestInferenceProvider bulkShardRequestInferenceProvider) { - processRequestsByShards(requestsByShard, clusterState, bulkShardRequestInferenceProvider); - } - - @Override - public void onFailure(Exception e) { - throw new ElasticsearchException("Error loading inference models", e); - } - } - ); - } - - void processRequestsByShards( - Map> requestsByShard, - ClusterState clusterState, - BulkShardRequestInferenceProvider bulkShardRequestInferenceProvider - ) { + String nodeId = 
clusterService.localNode().getId(); Runnable onBulkItemsComplete = () -> { listener.onResponse( new BulkResponse(responses.toArray(new BulkItemResponse[responses.length()]), buildTookInMillis(startTimeNanos)) @@ -230,68 +197,33 @@ void processRequestsByShards( // Allow memory for bulk shard request items to be reclaimed before all items have been completed bulkRequest = null; }; + try (RefCountingRunnable bulkItemRequestCompleteRefCount = new RefCountingRunnable(onBulkItemsComplete)) { for (Map.Entry> entry : requestsByShard.entrySet()) { final ShardId shardId = entry.getKey(); final List requests = entry.getValue(); - BulkShardRequest bulkShardRequest = createBulkShardRequest(clusterState, shardId, requests); - - Releasable ref = bulkItemRequestCompleteRefCount.acquire(); - final BiConsumer bulkItemFailedListener = (itemReq, e) -> markBulkItemRequestFailed(itemReq, e); - bulkShardRequestInferenceProvider.processBulkShardRequest(bulkShardRequest, new ActionListener<>() { - @Override - public void onResponse(BulkShardRequest inferenceBulkShardRequest) { - executeBulkShardRequest( - inferenceBulkShardRequest, - ActionListener.releaseAfter(ActionListener.noop(), ref), - bulkItemFailedListener - ); - } - @Override - public void onFailure(Exception e) { - throw new ElasticsearchException("Error performing inference", e); - } - }, bulkItemFailedListener); + BulkShardRequest bulkShardRequest = new BulkShardRequest( + shardId, + bulkRequest.getRefreshPolicy(), + requests.toArray(new BulkItemRequest[0]) + ); + var indexMetadata = clusterState.getMetadata().index(shardId.getIndexName()); + if (indexMetadata != null && indexMetadata.getFieldInferenceMetadata().isEmpty() == false) { + bulkShardRequest.setFieldInferenceMetadata(indexMetadata.getFieldInferenceMetadata()); + } + bulkShardRequest.waitForActiveShards(bulkRequest.waitForActiveShards()); + bulkShardRequest.timeout(bulkRequest.timeout()); + bulkShardRequest.routedBasedOnClusterVersion(clusterState.version()); + if (task != null) { + bulkShardRequest.setParentTask(nodeId, task.getId()); + } + executeBulkShardRequest(bulkShardRequest, bulkItemRequestCompleteRefCount.acquire()); } } } - private BulkShardRequest createBulkShardRequest(ClusterState clusterState, ShardId shardId, List requests) { - BulkShardRequest bulkShardRequest = new BulkShardRequest( - shardId, - bulkRequest.getRefreshPolicy(), - requests.toArray(new BulkItemRequest[0]) - ); - bulkShardRequest.waitForActiveShards(bulkRequest.waitForActiveShards()); - bulkShardRequest.timeout(bulkRequest.timeout()); - bulkShardRequest.routedBasedOnClusterVersion(clusterState.version()); - if (task != null) { - bulkShardRequest.setParentTask(clusterService.localNode().getId(), task.getId()); - } - return bulkShardRequest; - } - - // When an item fails, store the failure in the responses array - private void markBulkItemRequestFailed(BulkItemRequest itemRequest, Exception e) { - final String indexName = itemRequest.index(); - - DocWriteRequest docWriteRequest = itemRequest.request(); - BulkItemResponse.Failure failure = new BulkItemResponse.Failure(indexName, docWriteRequest.id(), e); - responses.set(itemRequest.id(), BulkItemResponse.failure(itemRequest.id(), docWriteRequest.opType(), failure)); - } - - private void executeBulkShardRequest( - BulkShardRequest bulkShardRequest, - ActionListener listener, - BiConsumer bulkItemErrorListener - ) { - if (bulkShardRequest.items().length == 0) { - // No requests to execute due to previous errors, terminate early - listener.onResponse(bulkShardRequest); 
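The rewritten dispatch loop above leans on `RefCountingRunnable`: the loop itself holds one reference, every shard-level request acquires another, and `onBulkItemsComplete` fires only once all of them have been released. Below is a minimal, self-contained sketch of that completion pattern; it uses plain `java.util.concurrent` types as stand-ins for the Elasticsearch helper, and the class and shard names are illustrative only.

```java
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicInteger;

public class RefCountedFanOut {
    public static void main(String[] args) {
        ExecutorService executor = Executors.newFixedThreadPool(4);
        List<String> shardRequests = List.of("shard-0", "shard-1", "shard-2");

        CompletableFuture<Void> allDone = new CompletableFuture<>();
        // Start at 1: the dispatch loop itself holds a reference, so the count
        // cannot reach zero while requests are still being submitted. This is
        // what the try-with-resources around RefCountingRunnable achieves above.
        AtomicInteger refs = new AtomicInteger(1);
        Runnable release = () -> {
            if (refs.decrementAndGet() == 0) {
                allDone.complete(null); // the moment the BulkResponse would be built
            }
        };

        for (String shard : shardRequests) {
            refs.incrementAndGet(); // acquire() for this shard-level request
            executor.submit(() -> {
                try {
                    System.out.println("executing " + shard);
                } finally {
                    release.run(); // released on success and failure alike
                }
            });
        }
        release.run(); // the loop's own reference, released when dispatch ends

        allDone.join();
        executor.shutdown();
    }
}
```

Releasing in a `finally` block on both paths is what allows `executeBulkShardRequest` to record shard-level errors as per-item failures without ever leaving the bulk response incomplete.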
- return; - } - + private void executeBulkShardRequest(BulkShardRequest bulkShardRequest, Releasable releaseOnFinish) { client.executeLocally(TransportShardBulkAction.TYPE, bulkShardRequest, new ActionListener<>() { @Override public void onResponse(BulkShardResponse bulkShardResponse) { @@ -302,17 +234,19 @@ public void onResponse(BulkShardResponse bulkShardResponse) { } responses.set(bulkItemResponse.getItemId(), bulkItemResponse); } - listener.onResponse(bulkShardRequest); + releaseOnFinish.close(); } @Override public void onFailure(Exception e) { // create failures for all relevant requests - BulkItemRequest[] items = bulkShardRequest.items(); - for (BulkItemRequest item : items) { - bulkItemErrorListener.accept(item, e); + for (BulkItemRequest request : bulkShardRequest.items()) { + final String indexName = request.index(); + DocWriteRequest docWriteRequest = request.request(); + BulkItemResponse.Failure failure = new BulkItemResponse.Failure(indexName, docWriteRequest.id(), e); + responses.set(request.id(), BulkItemResponse.failure(request.id(), docWriteRequest.opType(), failure)); } - listener.onFailure(e); + releaseOnFinish.close(); } }); } diff --git a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequest.java b/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequest.java index bd929b9a2204e..39fa791a3e27d 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequest.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequest.java @@ -15,6 +15,7 @@ import org.elasticsearch.action.support.replication.ReplicatedWriteRequest; import org.elasticsearch.action.support.replication.ReplicationRequest; import org.elasticsearch.action.update.UpdateRequest; +import org.elasticsearch.cluster.metadata.FieldInferenceMetadata; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.util.set.Sets; @@ -33,6 +34,8 @@ public final class BulkShardRequest extends ReplicatedWriteRequest i.readOptionalWriteable(inpt -> new BulkItemRequest(shardId, inpt)), BulkItemRequest[]::new); @@ -44,6 +47,30 @@ public BulkShardRequest(ShardId shardId, RefreshPolicy refreshPolicy, BulkItemRe setRefreshPolicy(refreshPolicy); } + /** + * Public for test + * Set the transient metadata indicating that this request requires running inference before proceeding. + */ + public void setFieldInferenceMetadata(FieldInferenceMetadata fieldsInferenceMetadata) { + this.fieldsInferenceMetadataMap = fieldsInferenceMetadata; + } + + /** + * Consumes the inference metadata to execute inference on the bulk items just once. + */ + public FieldInferenceMetadata consumeFieldInferenceMetadata() { + FieldInferenceMetadata ret = fieldsInferenceMetadataMap; + fieldsInferenceMetadataMap = null; + return ret; + } + + /** + * Public for test + */ + public FieldInferenceMetadata getFieldsInferenceMetadataMap() { + return fieldsInferenceMetadataMap; + } + public long totalSizeInBytes() { long totalSizeInBytes = 0; for (int i = 0; i < items.length; i++) { diff --git a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java b/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java deleted file mode 100644 index e80530f75cf4b..0000000000000 --- a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequestInferenceProvider.java +++ /dev/null @@ -1,338 +0,0 @@ -/* - * Copyright Elasticsearch B.V. 
and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0 and the Server Side Public License, v 1; you may not use this file except - * in compliance with, at your election, the Elastic License 2.0 or the Server - * Side Public License, v 1. - */ - -package org.elasticsearch.action.bulk; - -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.DocWriteRequest; -import org.elasticsearch.action.index.IndexRequest; -import org.elasticsearch.action.support.RefCountingRunnable; -import org.elasticsearch.action.update.UpdateRequest; -import org.elasticsearch.cluster.ClusterState; -import org.elasticsearch.cluster.metadata.FieldInferenceMetadata; -import org.elasticsearch.common.TriConsumer; -import org.elasticsearch.core.Releasable; -import org.elasticsearch.index.shard.ShardId; -import org.elasticsearch.inference.InferenceResults; -import org.elasticsearch.inference.InferenceService; -import org.elasticsearch.inference.InferenceServiceRegistry; -import org.elasticsearch.inference.InferenceServiceResults; -import org.elasticsearch.inference.InputType; -import org.elasticsearch.inference.Model; -import org.elasticsearch.inference.ModelRegistry; -import org.elasticsearch.inference.SemanticTextModelSettings; - -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashMap; -import java.util.HashSet; -import java.util.LinkedHashMap; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.Set; -import java.util.concurrent.ConcurrentHashMap; -import java.util.function.BiConsumer; -import java.util.stream.Collectors; - -/** - * Performs inference on a {@link BulkShardRequest}, updating the source of each document with the inference results. 
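For context on the `BulkShardRequest` changes earlier in this patch: the new `setFieldInferenceMetadata`/`consumeFieldInferenceMetadata` pair is a consume-once handoff, so whichever component first claims the metadata (the new action filter, going by the PR title) runs inference exactly once. A short sketch of the idiom, with simplified stand-in types and a hypothetical inference endpoint id:

```java
import java.util.Map;

public class ConsumeOnceDemo {
    // Simplified stand-ins for the types touched by this patch.
    record FieldInferenceMetadata(Map<String, String> fieldToInferenceId) {}

    static class ShardRequest {
        private FieldInferenceMetadata fieldInferenceMetadata;

        // Transient coordination-time state, set by the coordinating node.
        void setFieldInferenceMetadata(FieldInferenceMetadata metadata) {
            this.fieldInferenceMetadata = metadata;
        }

        // Consume-once: the first caller takes the metadata and clears it, so
        // a request passing through the chain again cannot re-run inference.
        FieldInferenceMetadata consumeFieldInferenceMetadata() {
            FieldInferenceMetadata ret = fieldInferenceMetadata;
            fieldInferenceMetadata = null;
            return ret;
        }
    }

    public static void main(String[] args) {
        ShardRequest request = new ShardRequest();
        // "my-elser-endpoint" is a hypothetical inference endpoint id.
        request.setFieldInferenceMetadata(new FieldInferenceMetadata(Map.of("body", "my-elser-endpoint")));
        System.out.println(request.consumeFieldInferenceMetadata() != null); // true: run inference now
        System.out.println(request.consumeFieldInferenceMetadata() != null); // false: already handled
    }
}
```

Returning the value and nulling the field in one step means a request that re-enters the filter chain sees `null` and proceeds straight to indexing.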
- */ -public class BulkShardRequestInferenceProvider { - - // Root field name for storing inference results - public static final String ROOT_INFERENCE_FIELD = "_semantic_text_inference"; - - // Contains the original text for the field - - public static final String INFERENCE_RESULTS = "inference_results"; - public static final String INFERENCE_CHUNKS_RESULTS = "inference"; - public static final String INFERENCE_CHUNKS_TEXT = "text"; - - private final ClusterState clusterState; - private final Map inferenceProvidersMap; - - private record InferenceProvider(Model model, InferenceService service) { - private InferenceProvider { - Objects.requireNonNull(model); - Objects.requireNonNull(service); - } - } - - BulkShardRequestInferenceProvider(ClusterState clusterState, Map inferenceProvidersMap) { - this.clusterState = clusterState; - this.inferenceProvidersMap = inferenceProvidersMap; - } - - public static void getInstance( - InferenceServiceRegistry inferenceServiceRegistry, - ModelRegistry modelRegistry, - ClusterState clusterState, - Set shardIds, - ActionListener listener - ) { - Set inferenceIds = shardIds.stream() - .map(ShardId::getIndex) - .collect(Collectors.toSet()) - .stream() - .map(index -> clusterState.metadata().index(index).getFieldInferenceMetadata().getFieldInferenceOptions().values()) - .flatMap(o -> o.stream().map(FieldInferenceMetadata.FieldInferenceOptions::inferenceId)) - .collect(Collectors.toSet()); - final Map inferenceProviderMap = new ConcurrentHashMap<>(); - Runnable onModelLoadingComplete = () -> listener.onResponse( - new BulkShardRequestInferenceProvider(clusterState, inferenceProviderMap) - ); - try (var refs = new RefCountingRunnable(onModelLoadingComplete)) { - for (var inferenceId : inferenceIds) { - ActionListener modelLoadingListener = new ActionListener<>() { - @Override - public void onResponse(ModelRegistry.UnparsedModel unparsedModel) { - var service = inferenceServiceRegistry.getService(unparsedModel.service()); - if (service.isEmpty() == false) { - InferenceProvider inferenceProvider = new InferenceProvider( - service.get() - .parsePersistedConfigWithSecrets( - inferenceId, - unparsedModel.taskType(), - unparsedModel.settings(), - unparsedModel.secrets() - ), - service.get() - ); - inferenceProviderMap.put(inferenceId, inferenceProvider); - } - } - - @Override - public void onFailure(Exception e) { - // Failure on loading a model should not prevent the rest from being loaded and used. - // When the model is actually retrieved via the inference ID in the inference process, it will fail - // and the user will get the details on the inference failure. - } - }; - - modelRegistry.getModelWithSecrets(inferenceId, ActionListener.releaseAfter(modelLoadingListener, refs.acquire())); - } - } - } - - /** - * Performs inference on the fields that have inference models for a bulk shard request. Bulk items from - * the original request will be modified with the inference results, to avoid copying the entire requests from - * the original bulk request. - * - * @param bulkShardRequest original BulkShardRequest that will be modified with inference results. 
- * @param listener listener to be called when the inference process is finished with the new BulkShardRequest, - * which may have fewer items than the original because of inference failures - * @param onBulkItemFailure invoked when a bulk item fails inference - */ - public void processBulkShardRequest( - BulkShardRequest bulkShardRequest, - ActionListener listener, - BiConsumer onBulkItemFailure - ) { - - Map> fieldsForInferenceIds = getFieldsForInferenceIds( - clusterState.metadata().index(bulkShardRequest.shardId().getIndex()).getFieldInferenceMetadata().getFieldInferenceOptions() - ); - // No inference fields? Terminate early - if (fieldsForInferenceIds.isEmpty()) { - listener.onResponse(bulkShardRequest); - return; - } - - Set failedItems = Collections.synchronizedSet(new HashSet<>()); - Runnable onInferenceComplete = () -> { - if (failedItems.isEmpty()) { - listener.onResponse(bulkShardRequest); - return; - } - // Remove failed items from the original bulk shard request - BulkItemRequest[] originalItems = bulkShardRequest.items(); - BulkItemRequest[] newItems = new BulkItemRequest[originalItems.length - failedItems.size()]; - for (int i = 0, j = 0; i < originalItems.length; i++) { - if (failedItems.contains(i) == false) { - newItems[j++] = originalItems[i]; - } - } - BulkShardRequest newBulkShardRequest = new BulkShardRequest( - bulkShardRequest.shardId(), - bulkShardRequest.getRefreshPolicy(), - newItems - ); - listener.onResponse(newBulkShardRequest); - }; - TriConsumer onBulkItemFailureWithIndex = (bulkItemRequest, i, e) -> { - failedItems.add(i); - onBulkItemFailure.accept(bulkItemRequest, e); - }; - try (var bulkItemReqRef = new RefCountingRunnable(onInferenceComplete)) { - BulkItemRequest[] items = bulkShardRequest.items(); - for (int i = 0; i < items.length; i++) { - BulkItemRequest bulkItemRequest = items[i]; - // Bulk item might be null because of previous errors, skip in that case - if (bulkItemRequest != null) { - performInferenceOnBulkItemRequest( - bulkItemRequest, - fieldsForInferenceIds, - i, - onBulkItemFailureWithIndex, - bulkItemReqRef.acquire() - ); - } - } - } - } - - private static Map> getFieldsForInferenceIds( - Map fieldInferenceMap - ) { - Map> fieldsForInferenceIdsMap = new HashMap<>(); - for (Map.Entry entry : fieldInferenceMap.entrySet()) { - String fieldName = entry.getKey(); - String inferenceId = entry.getValue().inferenceId(); - - // Get or create the set associated with the inferenceId - Set fields = fieldsForInferenceIdsMap.computeIfAbsent(inferenceId, k -> new HashSet<>()); - fields.add(fieldName); - } - - return fieldsForInferenceIdsMap; - } - - @SuppressWarnings("unchecked") - private void performInferenceOnBulkItemRequest( - BulkItemRequest bulkItemRequest, - Map> fieldsForModels, - Integer itemIndex, - TriConsumer onBulkItemFailure, - Releasable releaseOnFinish - ) { - - DocWriteRequest docWriteRequest = bulkItemRequest.request(); - Map sourceMap = null; - if (docWriteRequest instanceof IndexRequest indexRequest) { - sourceMap = indexRequest.sourceAsMap(); - } else if (docWriteRequest instanceof UpdateRequest updateRequest) { - sourceMap = updateRequest.docAsUpsert() ? 
updateRequest.upsertRequest().sourceAsMap() : updateRequest.doc().sourceAsMap(); - } - if (sourceMap == null || sourceMap.isEmpty()) { - releaseOnFinish.close(); - return; - } - final Map docMap = new ConcurrentHashMap<>(sourceMap); - - // When a document completes processing, update the source with the inference - try (var docRef = new RefCountingRunnable(() -> { - if (docWriteRequest instanceof IndexRequest indexRequest) { - indexRequest.source(docMap); - } else if (docWriteRequest instanceof UpdateRequest updateRequest) { - if (updateRequest.docAsUpsert()) { - updateRequest.upsertRequest().source(docMap); - } else { - updateRequest.doc().source(docMap); - } - } - releaseOnFinish.close(); - })) { - - Map rootInferenceFieldMap; - try { - rootInferenceFieldMap = (Map) docMap.computeIfAbsent( - ROOT_INFERENCE_FIELD, - k -> new HashMap() - ); - } catch (ClassCastException e) { - onBulkItemFailure.apply( - bulkItemRequest, - itemIndex, - new IllegalArgumentException("Inference result field [" + ROOT_INFERENCE_FIELD + "] is not an object") - ); - return; - } - - for (Map.Entry> fieldModelsEntrySet : fieldsForModels.entrySet()) { - String modelId = fieldModelsEntrySet.getKey(); - List inferenceFieldNames = getFieldNamesForInference(fieldModelsEntrySet, docMap); - if (inferenceFieldNames.isEmpty()) { - continue; - } - - InferenceProvider inferenceProvider = inferenceProvidersMap.get(modelId); - if (inferenceProvider == null) { - onBulkItemFailure.apply( - bulkItemRequest, - itemIndex, - new IllegalArgumentException("No inference provider found for model ID " + modelId) - ); - return; - } - ActionListener inferenceResultsListener = new ActionListener<>() { - @Override - public void onResponse(InferenceServiceResults results) { - if (results == null) { - onBulkItemFailure.apply( - bulkItemRequest, - itemIndex, - new IllegalArgumentException( - "No inference results retrieved for model ID " + modelId + " in document " + docWriteRequest.id() - ) - ); - } - - int i = 0; - for (InferenceResults inferenceResults : results.transformToCoordinationFormat()) { - String inferenceFieldName = inferenceFieldNames.get(i++); - Map inferenceFieldResult = new LinkedHashMap<>(); - inferenceFieldResult.putAll(new SemanticTextModelSettings(inferenceProvider.model).asMap()); - inferenceFieldResult.put( - INFERENCE_RESULTS, - List.of( - Map.of( - INFERENCE_CHUNKS_RESULTS, - inferenceResults.asMap("output").get("output"), - INFERENCE_CHUNKS_TEXT, - docMap.get(inferenceFieldName) - ) - ) - ); - rootInferenceFieldMap.put(inferenceFieldName, inferenceFieldResult); - } - } - - @Override - public void onFailure(Exception e) { - onBulkItemFailure.apply(bulkItemRequest, itemIndex, e); - } - }; - inferenceProvider.service() - .infer( - inferenceProvider.model, - inferenceFieldNames.stream().map(docMap::get).map(String::valueOf).collect(Collectors.toList()), - // TODO check for additional settings needed - Map.of(), - InputType.INGEST, - ActionListener.releaseAfter(inferenceResultsListener, docRef.acquire()) - ); - } - } - } - - private static List getFieldNamesForInference(Map.Entry> fieldModelsEntrySet, Map docMap) { - List inferenceFieldNames = new ArrayList<>(); - for (String inferenceField : fieldModelsEntrySet.getValue()) { - Object fieldValue = docMap.get(inferenceField); - - // Perform inference on string, non-null values - if (fieldValue instanceof String) { - inferenceFieldNames.add(inferenceField); - } - } - return inferenceFieldNames; - } -} diff --git 
a/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java b/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java index b05464b3a10c2..a2445e95a572f 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java @@ -57,8 +57,6 @@ import org.elasticsearch.index.IndexingPressure; import org.elasticsearch.index.VersionType; import org.elasticsearch.indices.SystemIndices; -import org.elasticsearch.inference.InferenceServiceRegistry; -import org.elasticsearch.inference.ModelRegistry; import org.elasticsearch.ingest.IngestService; import org.elasticsearch.node.NodeClosedException; import org.elasticsearch.tasks.Task; @@ -100,8 +98,6 @@ public class TransportBulkAction extends HandledTransportAction responses = new AtomicArray<>(bulkRequest.requests.size()); // Optimizing when there are no prerequisite actions if (indicesToAutoCreate.isEmpty() && dataStreamsToBeRolledOver.isEmpty()) { - executeBulk(task, bulkRequest, startTime, executorName, responses, indicesThatCannotBeCreated, listener); + executeBulk(task, bulkRequest, startTime, listener, executorName, responses, indicesThatCannotBeCreated); return; } Runnable executeBulkRunnable = () -> threadPool.executor(executorName).execute(new ActionRunnable<>(listener) { @Override protected void doRun() { - executeBulk(task, bulkRequest, startTime, executorName, responses, indicesThatCannotBeCreated, listener); + executeBulk(task, bulkRequest, startTime, listener, executorName, responses, indicesThatCannotBeCreated); } }); try (RefCountingRunnable refs = new RefCountingRunnable(executeBulkRunnable)) { @@ -649,10 +633,10 @@ void executeBulk( Task task, BulkRequest bulkRequest, long startTimeNanos, + ActionListener listener, String executorName, AtomicArray responses, - Map indicesThatCannotBeCreated, - ActionListener listener + Map indicesThatCannotBeCreated ) { new BulkOperation( task, @@ -666,8 +650,6 @@ void executeBulk( indexNameExpressionResolver, relativeTimeProvider, startTimeNanos, - modelRegistry, - inferenceServiceRegistry, listener ).run(); } diff --git a/server/src/main/java/org/elasticsearch/action/bulk/TransportSimulateBulkAction.java b/server/src/main/java/org/elasticsearch/action/bulk/TransportSimulateBulkAction.java index c8dc3e7b7ffd5..f65d0f462fde6 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/TransportSimulateBulkAction.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/TransportSimulateBulkAction.java @@ -58,9 +58,7 @@ public TransportSimulateBulkAction( indexNameExpressionResolver, indexingPressure, systemIndices, - System::nanoTime, - null, - null + System::nanoTime ); } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index 85221896f35fd..f4a9e1727abd6 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -1112,6 +1112,10 @@ public String typeName() { return CONTENT_TYPE; } + public Integer getDims() { + return dims; + } + @Override public ValueFetcher valueFetcher(SearchExecutionContext context, String format) { if (format != null) { diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistry.java 
b/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistry.java index ce6f1b21b734c..d5973807d9d78 100644 --- a/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistry.java +++ b/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistry.java @@ -13,41 +13,49 @@ import java.io.Closeable; import java.io.IOException; +import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.function.Function; +import java.util.stream.Collectors; + +public class InferenceServiceRegistry implements Closeable { + + private final Map services; + private final List namedWriteables = new ArrayList<>(); + + public InferenceServiceRegistry( + List inferenceServicePlugins, + InferenceServiceExtension.InferenceServiceFactoryContext factoryContext + ) { + // TODO check names are unique + services = inferenceServicePlugins.stream() + .flatMap(r -> r.getInferenceServiceFactories().stream()) + .map(factory -> factory.create(factoryContext)) + .collect(Collectors.toMap(InferenceService::name, Function.identity())); + } -public interface InferenceServiceRegistry extends Closeable { - void init(Client client); - - Map getServices(); - - Optional getService(String serviceName); - - List getNamedWriteables(); - - class NoopInferenceServiceRegistry implements InferenceServiceRegistry { - public NoopInferenceServiceRegistry() {} + public void init(Client client) { + services.values().forEach(s -> s.init(client)); + } - @Override - public void init(Client client) {} + public Map getServices() { + return services; + } - @Override - public Map getServices() { - return Map.of(); - } + public Optional getService(String serviceName) { + return Optional.ofNullable(services.get(serviceName)); + } - @Override - public Optional getService(String serviceName) { - return Optional.empty(); - } + public List getNamedWriteables() { + return namedWriteables; + } - @Override - public List getNamedWriteables() { - return List.of(); + @Override + public void close() throws IOException { + for (var service : services.values()) { + service.close(); } - - @Override - public void close() throws IOException {} } } diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistryImpl.java b/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistryImpl.java deleted file mode 100644 index f0a990ded98ce..0000000000000 --- a/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistryImpl.java +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0 and the Server Side Public License, v 1; you may not use this file except - * in compliance with, at your election, the Elastic License 2.0 or the Server - * Side Public License, v 1. 
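The reworked `InferenceServiceRegistry` above is now a concrete class that keys each `InferenceService` by name via `Collectors.toMap`. Here is a small sketch of that construction and the `Optional`-based lookup; the service names are made up, and the duplicate-key behavior shown at the end is why the `// TODO check names are unique` comment matters:

```java
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
import java.util.stream.Collectors;

public class ServiceRegistryDemo {
    // Minimal stand-in for an InferenceService: only the name matters here.
    record Service(String name) {}

    public static void main(String[] args) {
        // Hypothetical service names; real ones come from the service plugins.
        List<Service> fromPlugins = List.of(new Service("elser"), new Service("openai"));

        // Same shape as the constructor above: key each service by its name.
        Map<String, Service> services = fromPlugins.stream()
            .collect(Collectors.toMap(Service::name, Function.identity()));

        // getService() style lookup: unknown names become Optional.empty()
        // rather than null, forcing callers to handle the missing case.
        Optional<Service> hit = Optional.ofNullable(services.get("elser"));
        Optional<Service> miss = Optional.ofNullable(services.get("cohere"));
        System.out.println(hit.isPresent() + " / " + miss.isPresent()); // true / false

        // Without a merge function, Collectors.toMap throws on duplicate keys.
        try {
            List.of(new Service("elser"), new Service("elser")).stream()
                .collect(Collectors.toMap(Service::name, Function.identity()));
        } catch (IllegalStateException e) {
            System.out.println("duplicate service name rejected: " + e.getMessage());
        }
    }
}
```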
- */ - -package org.elasticsearch.inference; - -import org.elasticsearch.client.internal.Client; -import org.elasticsearch.common.io.stream.NamedWriteableRegistry; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.function.Function; -import java.util.stream.Collectors; - -public class InferenceServiceRegistryImpl implements InferenceServiceRegistry { - - private final Map services; - private final List namedWriteables = new ArrayList<>(); - - public InferenceServiceRegistryImpl( - List inferenceServicePlugins, - InferenceServiceExtension.InferenceServiceFactoryContext factoryContext - ) { - // TODO check names are unique - services = inferenceServicePlugins.stream() - .flatMap(r -> r.getInferenceServiceFactories().stream()) - .map(factory -> factory.create(factoryContext)) - .collect(Collectors.toMap(InferenceService::name, Function.identity())); - } - - @Override - public void init(Client client) { - services.values().forEach(s -> s.init(client)); - } - - @Override - public Map getServices() { - return services; - } - - @Override - public Optional getService(String serviceName) { - return Optional.ofNullable(services.get(serviceName)); - } - - @Override - public List getNamedWriteables() { - return namedWriteables; - } - - @Override - public void close() throws IOException { - for (var service : services.values()) { - service.close(); - } - } -} diff --git a/server/src/main/java/org/elasticsearch/inference/ModelRegistry.java b/server/src/main/java/org/elasticsearch/inference/ModelRegistry.java deleted file mode 100644 index fa90d5ba6f756..0000000000000 --- a/server/src/main/java/org/elasticsearch/inference/ModelRegistry.java +++ /dev/null @@ -1,99 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0 and the Server Side Public License, v 1; you may not use this file except - * in compliance with, at your election, the Elastic License 2.0 or the Server - * Side Public License, v 1. - */ - -package org.elasticsearch.inference; - -import org.elasticsearch.action.ActionListener; - -import java.util.List; -import java.util.Map; - -public interface ModelRegistry { - - /** - * Get a model. - * Secret settings are not included - * @param inferenceEntityId Model to get - * @param listener Model listener - */ - void getModel(String inferenceEntityId, ActionListener listener); - - /** - * Get a model with its secret settings - * @param inferenceEntityId Model to get - * @param listener Model listener - */ - void getModelWithSecrets(String inferenceEntityId, ActionListener listener); - - /** - * Get all models of a particular task type. - * Secret settings are not included - * @param taskType The task type - * @param listener Models listener - */ - void getModelsByTaskType(TaskType taskType, ActionListener> listener); - - /** - * Get all models. - * Secret settings are not included - * @param listener Models listener - */ - void getAllModels(ActionListener> listener); - - void storeModel(Model model, ActionListener listener); - - void deleteModel(String modelId, ActionListener listener); - - /** - * Semi parsed model where inference entity id, task type and service - * are known but the settings are not parsed. 
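The `UnparsedModel` described by the javadoc above captures a two-phase parsing scheme: the registry resolves the identity fields (inference entity id, task type, service name) while settings and secrets stay as raw maps until the owning service parses them, as `parsePersistedConfigWithSecrets` does elsewhere in this patch. A hedged sketch of that split, with stand-in record types and invented settings keys:

```java
import java.util.Map;

public class TwoPhaseModelParsing {
    // Mirrors UnparsedModel: identity fields are parsed, settings are not.
    record UnparsedModel(String inferenceEntityId, String taskType, String service,
                         Map<String, Object> settings, Map<String, Object> secrets) {}

    // Hypothetical fully parsed form for one concrete service.
    record SparseEmbeddingModel(String inferenceEntityId, int numThreads, String apiKey) {}

    // Phase two lives with the service, which knows its own settings schema;
    // in the real code this role is played by parsePersistedConfigWithSecrets.
    static SparseEmbeddingModel parse(UnparsedModel unparsed) {
        if ("sparse_embedding".equals(unparsed.taskType()) == false) {
            throw new IllegalArgumentException("unsupported task type: " + unparsed.taskType());
        }
        return new SparseEmbeddingModel(
            unparsed.inferenceEntityId(),
            (Integer) unparsed.settings().getOrDefault("num_threads", 1),
            (String) unparsed.secrets().get("api_key")
        );
    }

    public static void main(String[] args) {
        // The settings/secrets keys here are invented for the example.
        UnparsedModel stored = new UnparsedModel("my-endpoint", "sparse_embedding", "elser",
            Map.of("num_threads", 2), Map.of("api_key", "xyz"));
        System.out.println(parse(stored));
    }
}
```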
- */ - record UnparsedModel( - String inferenceEntityId, - TaskType taskType, - String service, - Map settings, - Map secrets - ) {} - - class NoopModelRegistry implements ModelRegistry { - @Override - public void getModel(String modelId, ActionListener listener) { - fail(listener); - } - - @Override - public void getModelsByTaskType(TaskType taskType, ActionListener> listener) { - listener.onResponse(List.of()); - } - - @Override - public void getAllModels(ActionListener> listener) { - listener.onResponse(List.of()); - } - - @Override - public void storeModel(Model model, ActionListener listener) { - fail(listener); - } - - @Override - public void deleteModel(String modelId, ActionListener listener) { - fail(listener); - } - - @Override - public void getModelWithSecrets(String inferenceEntityId, ActionListener listener) { - fail(listener); - } - - private static void fail(ActionListener listener) { - listener.onFailure(new IllegalArgumentException("No model registry configured")); - } - } -} diff --git a/server/src/main/java/org/elasticsearch/node/NodeConstruction.java b/server/src/main/java/org/elasticsearch/node/NodeConstruction.java index 15ebe2752451d..5bf19c4b87157 100644 --- a/server/src/main/java/org/elasticsearch/node/NodeConstruction.java +++ b/server/src/main/java/org/elasticsearch/node/NodeConstruction.java @@ -127,8 +127,6 @@ import org.elasticsearch.indices.recovery.plan.PeerOnlyRecoveryPlannerService; import org.elasticsearch.indices.recovery.plan.RecoveryPlannerService; import org.elasticsearch.indices.recovery.plan.ShardSnapshotsService; -import org.elasticsearch.inference.InferenceServiceRegistry; -import org.elasticsearch.inference.ModelRegistry; import org.elasticsearch.ingest.IngestService; import org.elasticsearch.monitor.MonitorService; import org.elasticsearch.monitor.fs.FsHealthService; @@ -147,7 +145,6 @@ import org.elasticsearch.plugins.ClusterPlugin; import org.elasticsearch.plugins.DiscoveryPlugin; import org.elasticsearch.plugins.HealthPlugin; -import org.elasticsearch.plugins.InferenceRegistryPlugin; import org.elasticsearch.plugins.IngestPlugin; import org.elasticsearch.plugins.MapperPlugin; import org.elasticsearch.plugins.MetadataUpgrader; @@ -1114,18 +1111,6 @@ record PluginServiceInstances( ); } - // Register noop versions of inference services if Inference plugin is not available - Optional inferenceRegistryPlugin = getSinglePlugin(InferenceRegistryPlugin.class); - modules.bindToInstance( - InferenceServiceRegistry.class, - inferenceRegistryPlugin.map(InferenceRegistryPlugin::getInferenceServiceRegistry) - .orElse(new InferenceServiceRegistry.NoopInferenceServiceRegistry()) - ); - modules.bindToInstance( - ModelRegistry.class, - inferenceRegistryPlugin.map(InferenceRegistryPlugin::getModelRegistry).orElse(new ModelRegistry.NoopModelRegistry()) - ); - injector = modules.createInjector(); postInjection(clusterModule, actionModule, clusterService, transportService, featureService); diff --git a/server/src/main/java/org/elasticsearch/plugins/InferenceRegistryPlugin.java b/server/src/main/java/org/elasticsearch/plugins/InferenceRegistryPlugin.java deleted file mode 100644 index 696c3a067dad1..0000000000000 --- a/server/src/main/java/org/elasticsearch/plugins/InferenceRegistryPlugin.java +++ /dev/null @@ -1,22 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. 
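With the `InferenceRegistryPlugin` deletion starting above, server code no longer needs any inference hook; per the PR title and the new `ShardBulkInferenceActionFilter` in the diffstat, the work moves into an action filter inside the inference plugin that intercepts bulk shard requests. That filter's code is not part of this excerpt, so the following is only a sketch of the general intercept-mutate-continue shape, using simplified stand-in interfaces rather than the real `ActionFilter` API (the action name follows `TransportShardBulkAction`'s `indices:data/write/bulk[s]`):

```java
public class ActionFilterSketch {
    // Simplified stand-ins: the real ActionFilter/ActionFilterChain carry a
    // Task, response types, and an ordering, all omitted here for brevity.
    interface FilterChain<R> { void proceed(String action, R request); }
    interface Filter { <R> void apply(String action, R request, FilterChain<R> chain); }

    record ShardRequest(String index, StringBuilder source) {}

    // Intercept-mutate-continue: matching requests are enriched in place
    // (here with a fake embedding marker), then handed back to the chain.
    static class InferenceFilter implements Filter {
        @Override
        public <R> void apply(String action, R request, FilterChain<R> chain) {
            if ("indices:data/write/bulk[s]".equals(action) && request instanceof ShardRequest shardRequest) {
                shardRequest.source().append(" +embeddings");
            }
            chain.proceed(action, request); // always continue downstream
        }
    }

    public static void main(String[] args) {
        ShardRequest request = new ShardRequest("test-index", new StringBuilder("{doc}"));
        FilterChain<ShardRequest> terminal =
            (action, r) -> System.out.println("executing " + action + " on " + r.index() + ": " + r.source());
        new InferenceFilter().apply("indices:data/write/bulk[s]", request, terminal);
    }
}
```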
Licensed under the Elastic License - * 2.0 and the Server Side Public License, v 1; you may not use this file except - * in compliance with, at your election, the Elastic License 2.0 or the Server - * Side Public License, v 1. - */ - -package org.elasticsearch.plugins; - -import org.elasticsearch.inference.InferenceServiceRegistry; -import org.elasticsearch.inference.ModelRegistry; - -/** - * Plugins that provide inference services should implement this interface. - * There should be a single one in the classpath, as we currently support a single instance for ModelRegistry / InferenceServiceRegistry. - */ -public interface InferenceRegistryPlugin { - InferenceServiceRegistry getInferenceServiceRegistry(); - - ModelRegistry getModelRegistry(); -} diff --git a/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java b/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java deleted file mode 100644 index c3887f506b891..0000000000000 --- a/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java +++ /dev/null @@ -1,670 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0 and the Server Side Public License, v 1; you may not use this file except - * in compliance with, at your election, the Elastic License 2.0 or the Server - * Side Public License, v 1. - */ - -package org.elasticsearch.action.bulk; - -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.DocWriteRequest; -import org.elasticsearch.action.index.IndexRequest; -import org.elasticsearch.action.index.IndexResponse; -import org.elasticsearch.client.internal.node.NodeClient; -import org.elasticsearch.cluster.ClusterName; -import org.elasticsearch.cluster.ClusterState; -import org.elasticsearch.cluster.metadata.FieldInferenceMetadata; -import org.elasticsearch.cluster.metadata.IndexAbstraction; -import org.elasticsearch.cluster.metadata.IndexMetadata; -import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; -import org.elasticsearch.cluster.metadata.Metadata; -import org.elasticsearch.cluster.node.DiscoveryNodeUtils; -import org.elasticsearch.cluster.service.ClusterApplierService; -import org.elasticsearch.cluster.service.ClusterService; -import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.common.util.concurrent.AtomicArray; -import org.elasticsearch.core.Tuple; -import org.elasticsearch.index.IndexVersion; -import org.elasticsearch.inference.InferenceResults; -import org.elasticsearch.inference.InferenceService; -import org.elasticsearch.inference.InferenceServiceRegistry; -import org.elasticsearch.inference.InferenceServiceResults; -import org.elasticsearch.inference.InputType; -import org.elasticsearch.inference.Model; -import org.elasticsearch.inference.ModelRegistry; -import org.elasticsearch.inference.SemanticTextModelSettings; -import org.elasticsearch.inference.ServiceSettings; -import org.elasticsearch.inference.SimilarityMeasure; -import org.elasticsearch.inference.TaskType; -import org.elasticsearch.tasks.Task; -import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.threadpool.TestThreadPool; -import org.elasticsearch.threadpool.ThreadPool; -import org.junit.AfterClass; -import org.junit.BeforeClass; -import org.mockito.ArgumentCaptor; -import org.mockito.ArgumentMatcher; - -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collection; -import
java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.Optional; -import java.util.Set; -import java.util.function.Function; -import java.util.stream.Collectors; - -import static java.util.Collections.emptyMap; -import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.INFERENCE_CHUNKS_RESULTS; -import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.INFERENCE_CHUNKS_TEXT; -import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.INFERENCE_RESULTS; -import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.ROOT_INFERENCE_FIELD; -import static org.hamcrest.CoreMatchers.containsString; -import static org.hamcrest.CoreMatchers.equalTo; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyList; -import static org.mockito.ArgumentMatchers.anyMap; -import static org.mockito.ArgumentMatchers.argThat; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.doReturn; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.verifyNoMoreInteractions; -import static org.mockito.Mockito.when; - -public class BulkOperationTests extends ESTestCase { - - private static final String INDEX_NAME = "test-index"; - private static final String INFERENCE_SERVICE_1_ID = "inference_service_1_id"; - private static final String INFERENCE_SERVICE_2_ID = "inference_service_2_id"; - private static final String FIRST_INFERENCE_FIELD_SERVICE_1 = "first_inference_field_service_1"; - private static final String SECOND_INFERENCE_FIELD_SERVICE_1 = "second_inference_field_service_1"; - private static final String INFERENCE_FIELD_SERVICE_2 = "inference_field_service_2"; - private static final String SERVICE_1_ID = "elser_v2"; - private static final String SERVICE_2_ID = "e5"; - private static final String INFERENCE_FAILED_MSG = "Inference failed"; - private static TestThreadPool threadPool; - - public void testNoInference() { - - FieldInferenceMetadata fieldInferenceMetadata = FieldInferenceMetadata.EMPTY; - ModelRegistry modelRegistry = createModelRegistry( - Map.of(INFERENCE_SERVICE_1_ID, SERVICE_1_ID, INFERENCE_SERVICE_2_ID, SERVICE_2_ID) - ); - - Model model1 = mockModel(INFERENCE_SERVICE_1_ID); - InferenceService inferenceService1 = createInferenceService(model1); - Model model2 = mockModel(INFERENCE_SERVICE_2_ID); - InferenceService inferenceService2 = createInferenceService(model2); - InferenceServiceRegistry inferenceServiceRegistry = createInferenceServiceRegistry( - Map.of(SERVICE_1_ID, inferenceService1, SERVICE_2_ID, inferenceService2) - ); - - Map originalSource = Map.of( - randomAlphaOfLengthBetween(1, 20), - randomAlphaOfLengthBetween(1, 100), - randomAlphaOfLengthBetween(1, 20), - randomAlphaOfLengthBetween(1, 100) - ); - - @SuppressWarnings("unchecked") - ActionListener bulkOperationListener = mock(ActionListener.class); - BulkShardRequest bulkShardRequest = runBulkOperation( - originalSource, - fieldInferenceMetadata, - modelRegistry, - inferenceServiceRegistry, - true, - bulkOperationListener - ); - verify(bulkOperationListener).onResponse(any()); - - BulkItemRequest[] items = bulkShardRequest.items(); - assertThat(items.length, equalTo(1)); - - Map writtenDocSource = ((IndexRequest) items[0].request()).sourceAsMap(); - // Original doc source is preserved - originalSource.forEach((key, value) -> 
assertThat(writtenDocSource.get(key), equalTo(value))); - - // Check inference not invoked - verifyNoMoreInteractions(modelRegistry); - verifyNoMoreInteractions(inferenceServiceRegistry); - } - - private static Model mockModel(String inferenceServiceId) { - Model model = mock(Model.class); - - when(model.getInferenceEntityId()).thenReturn(inferenceServiceId); - TaskType taskType = randomBoolean() ? TaskType.SPARSE_EMBEDDING : TaskType.TEXT_EMBEDDING; - when(model.getTaskType()).thenReturn(taskType); - - ServiceSettings serviceSettings = mock(ServiceSettings.class); - when(model.getServiceSettings()).thenReturn(serviceSettings); - SimilarityMeasure similarity = switch (randomInt(2)) { - case 0 -> SimilarityMeasure.COSINE; - case 1 -> SimilarityMeasure.DOT_PRODUCT; - default -> null; - }; - when(serviceSettings.similarity()).thenReturn(similarity); - when(serviceSettings.dimensions()).thenReturn(randomBoolean() ? null : randomIntBetween(1, 1000)); - - return model; - } - - public void testFailedBulkShardRequest() { - - FieldInferenceMetadata fieldInferenceMetadata = FieldInferenceMetadata.EMPTY; - ModelRegistry modelRegistry = createModelRegistry(Map.of()); - InferenceServiceRegistry inferenceServiceRegistry = createInferenceServiceRegistry(Map.of()); - - Map originalSource = Map.of( - randomAlphaOfLengthBetween(1, 20), - randomAlphaOfLengthBetween(1, 100), - randomAlphaOfLengthBetween(1, 20), - randomAlphaOfLengthBetween(1, 100) - ); - - @SuppressWarnings("unchecked") - ActionListener bulkOperationListener = mock(ActionListener.class); - ArgumentCaptor bulkResponseCaptor = ArgumentCaptor.forClass(BulkResponse.class); - doAnswer(invocation -> null).when(bulkOperationListener).onResponse(bulkResponseCaptor.capture()); - - runBulkOperation( - originalSource, - fieldInferenceMetadata, - modelRegistry, - inferenceServiceRegistry, - bulkOperationListener, - true, - request -> new BulkShardResponse( - request.shardId(), - new BulkItemResponse[] { - BulkItemResponse.failure( - 0, - DocWriteRequest.OpType.INDEX, - new BulkItemResponse.Failure( - INDEX_NAME, - randomIdentifier(), - new IllegalArgumentException("Error on bulk shard request") - ) - ) } - ) - ); - verify(bulkOperationListener).onResponse(any()); - - BulkResponse bulkResponse = bulkResponseCaptor.getValue(); - assertTrue(bulkResponse.hasFailures()); - BulkItemResponse[] items = bulkResponse.getItems(); - assertTrue(items[0].isFailed()); - } - - @SuppressWarnings("unchecked") - public void testInference() { - - FieldInferenceMetadata fieldInferenceMetadata = new FieldInferenceMetadata( - Map.of( - FIRST_INFERENCE_FIELD_SERVICE_1, - new FieldInferenceMetadata.FieldInferenceOptions(INFERENCE_SERVICE_1_ID, Set.of()), - SECOND_INFERENCE_FIELD_SERVICE_1, - new FieldInferenceMetadata.FieldInferenceOptions(INFERENCE_SERVICE_1_ID, Set.of()), - INFERENCE_FIELD_SERVICE_2, - new FieldInferenceMetadata.FieldInferenceOptions(INFERENCE_SERVICE_2_ID, Set.of()) - ) - ); - - ModelRegistry modelRegistry = createModelRegistry( - Map.of(INFERENCE_SERVICE_1_ID, SERVICE_1_ID, INFERENCE_SERVICE_2_ID, SERVICE_2_ID) - ); - - Model model1 = mockModel(INFERENCE_SERVICE_1_ID); - InferenceService inferenceService1 = createInferenceService(model1); - Model model2 = mockModel(INFERENCE_SERVICE_2_ID); - InferenceService inferenceService2 = createInferenceService(model2); - InferenceServiceRegistry inferenceServiceRegistry = createInferenceServiceRegistry( - Map.of(SERVICE_1_ID, inferenceService1, SERVICE_2_ID, inferenceService2) - ); - - String 
firstInferenceTextService1 = randomAlphaOfLengthBetween(1, 100); - String secondInferenceTextService1 = randomAlphaOfLengthBetween(1, 100); - String inferenceTextService2 = randomAlphaOfLengthBetween(1, 100); - Map originalSource = Map.of( - FIRST_INFERENCE_FIELD_SERVICE_1, - firstInferenceTextService1, - SECOND_INFERENCE_FIELD_SERVICE_1, - secondInferenceTextService1, - INFERENCE_FIELD_SERVICE_2, - inferenceTextService2, - randomAlphaOfLengthBetween(1, 20), - randomAlphaOfLengthBetween(1, 100), - randomAlphaOfLengthBetween(1, 20), - randomAlphaOfLengthBetween(1, 100) - ); - - ActionListener bulkOperationListener = mock(ActionListener.class); - BulkShardRequest bulkShardRequest = runBulkOperation( - originalSource, - fieldInferenceMetadata, - modelRegistry, - inferenceServiceRegistry, - true, - bulkOperationListener - ); - verify(bulkOperationListener).onResponse(any()); - - BulkItemRequest[] items = bulkShardRequest.items(); - assertThat(items.length, equalTo(1)); - - Map writtenDocSource = ((IndexRequest) items[0].request()).sourceAsMap(); - // Original doc source is preserved - originalSource.forEach((key, value) -> assertThat(writtenDocSource.get(key), equalTo(value))); - - // Check inference results - verifyInferenceServiceInvoked( - modelRegistry, - INFERENCE_SERVICE_1_ID, - inferenceService1, - model1, - List.of(firstInferenceTextService1, secondInferenceTextService1) - ); - verifyInferenceServiceInvoked(modelRegistry, INFERENCE_SERVICE_2_ID, inferenceService2, model2, List.of(inferenceTextService2)); - checkInferenceResults( - originalSource, - writtenDocSource, - FIRST_INFERENCE_FIELD_SERVICE_1, - SECOND_INFERENCE_FIELD_SERVICE_1, - INFERENCE_FIELD_SERVICE_2 - ); - } - - public void testFailedInference() { - - FieldInferenceMetadata fieldInferenceMetadata = new FieldInferenceMetadata( - Map.of(FIRST_INFERENCE_FIELD_SERVICE_1, new FieldInferenceMetadata.FieldInferenceOptions(INFERENCE_SERVICE_1_ID, Set.of())) - ); - - ModelRegistry modelRegistry = createModelRegistry(Map.of(INFERENCE_SERVICE_1_ID, SERVICE_1_ID)); - - Model model = mockModel(INFERENCE_SERVICE_1_ID); - InferenceService inferenceService = createInferenceServiceThatFails(model); - InferenceServiceRegistry inferenceServiceRegistry = createInferenceServiceRegistry(Map.of(SERVICE_1_ID, inferenceService)); - - String firstInferenceTextService1 = randomAlphaOfLengthBetween(1, 100); - Map originalSource = Map.of( - FIRST_INFERENCE_FIELD_SERVICE_1, - firstInferenceTextService1, - randomAlphaOfLengthBetween(1, 20), - randomAlphaOfLengthBetween(1, 100) - ); - - ArgumentCaptor bulkResponseCaptor = ArgumentCaptor.forClass(BulkResponse.class); - @SuppressWarnings("unchecked") - ActionListener bulkOperationListener = mock(ActionListener.class); - runBulkOperation(originalSource, fieldInferenceMetadata, modelRegistry, inferenceServiceRegistry, false, bulkOperationListener); - - verify(bulkOperationListener).onResponse(bulkResponseCaptor.capture()); - BulkResponse bulkResponse = bulkResponseCaptor.getValue(); - assertTrue(bulkResponse.hasFailures()); - BulkItemResponse item = bulkResponse.getItems()[0]; - assertTrue(item.isFailed()); - assertThat(item.getFailure().getCause().getMessage(), equalTo(INFERENCE_FAILED_MSG)); - - verifyInferenceServiceInvoked(modelRegistry, INFERENCE_SERVICE_1_ID, inferenceService, model, List.of(firstInferenceTextService1)); - - } - - public void testInferenceFailsForIncorrectRootObject() { - - FieldInferenceMetadata fieldInferenceMetadata = new FieldInferenceMetadata( - 
Map.of(FIRST_INFERENCE_FIELD_SERVICE_1, new FieldInferenceMetadata.FieldInferenceOptions(INFERENCE_SERVICE_1_ID, Set.of())) - ); - - ModelRegistry modelRegistry = createModelRegistry(Map.of(INFERENCE_SERVICE_1_ID, SERVICE_1_ID)); - - Model model = mockModel(INFERENCE_SERVICE_1_ID); - InferenceService inferenceService = createInferenceServiceThatFails(model); - InferenceServiceRegistry inferenceServiceRegistry = createInferenceServiceRegistry(Map.of(SERVICE_1_ID, inferenceService)); - - Map originalSource = Map.of( - FIRST_INFERENCE_FIELD_SERVICE_1, - randomAlphaOfLengthBetween(1, 100), - ROOT_INFERENCE_FIELD, - "incorrect_root_object" - ); - - ArgumentCaptor bulkResponseCaptor = ArgumentCaptor.forClass(BulkResponse.class); - @SuppressWarnings("unchecked") - ActionListener bulkOperationListener = mock(ActionListener.class); - runBulkOperation(originalSource, fieldInferenceMetadata, modelRegistry, inferenceServiceRegistry, false, bulkOperationListener); - - verify(bulkOperationListener).onResponse(bulkResponseCaptor.capture()); - BulkResponse bulkResponse = bulkResponseCaptor.getValue(); - assertTrue(bulkResponse.hasFailures()); - BulkItemResponse item = bulkResponse.getItems()[0]; - assertTrue(item.isFailed()); - assertThat(item.getFailure().getCause().getMessage(), containsString("[_semantic_text_inference] is not an object")); - } - - public void testInferenceIdNotFound() { - - FieldInferenceMetadata fieldInferenceMetadata = new FieldInferenceMetadata( - Map.of( - FIRST_INFERENCE_FIELD_SERVICE_1, - new FieldInferenceMetadata.FieldInferenceOptions(INFERENCE_SERVICE_1_ID, Set.of()), - SECOND_INFERENCE_FIELD_SERVICE_1, - new FieldInferenceMetadata.FieldInferenceOptions(INFERENCE_SERVICE_1_ID, Set.of()), - INFERENCE_FIELD_SERVICE_2, - new FieldInferenceMetadata.FieldInferenceOptions(INFERENCE_SERVICE_2_ID, Set.of()) - ) - ); - - ModelRegistry modelRegistry = createModelRegistry(Map.of(INFERENCE_SERVICE_1_ID, SERVICE_1_ID)); - - Model model = mockModel(INFERENCE_SERVICE_1_ID); - InferenceService inferenceService = createInferenceService(model); - InferenceServiceRegistry inferenceServiceRegistry = createInferenceServiceRegistry(Map.of(SERVICE_1_ID, inferenceService)); - - Map originalSource = Map.of( - INFERENCE_FIELD_SERVICE_2, - randomAlphaOfLengthBetween(1, 100), - randomAlphaOfLengthBetween(1, 20), - randomAlphaOfLengthBetween(1, 100) - ); - - ArgumentCaptor bulkResponseCaptor = ArgumentCaptor.forClass(BulkResponse.class); - @SuppressWarnings("unchecked") - ActionListener bulkOperationListener = mock(ActionListener.class); - doAnswer(invocation -> null).when(bulkOperationListener).onResponse(bulkResponseCaptor.capture()); - - runBulkOperation(originalSource, fieldInferenceMetadata, modelRegistry, inferenceServiceRegistry, false, bulkOperationListener); - - verify(bulkOperationListener).onResponse(bulkResponseCaptor.capture()); - BulkResponse bulkResponse = bulkResponseCaptor.getValue(); - assertTrue(bulkResponse.hasFailures()); - BulkItemResponse item = bulkResponse.getItems()[0]; - assertTrue(item.isFailed()); - assertThat( - item.getFailure().getCause().getMessage(), - equalTo("No inference provider found for model ID " + INFERENCE_SERVICE_2_ID) - ); - } - - @SuppressWarnings("unchecked") - private static void checkInferenceResults( - Map docSource, - Map writtenDocSource, - String... 
inferenceFieldNames - ) { - - Map inferenceRootResultField = (Map) writtenDocSource.get( - BulkShardRequestInferenceProvider.ROOT_INFERENCE_FIELD - ); - - for (String inferenceFieldName : inferenceFieldNames) { - Map inferenceService1FieldResults = (Map) inferenceRootResultField.get(inferenceFieldName); - assertNotNull(inferenceService1FieldResults); - assertThat(inferenceService1FieldResults.size(), equalTo(2)); - Map modelSettings = (Map) inferenceService1FieldResults.get(SemanticTextModelSettings.NAME); - assertNotNull(modelSettings); - assertNotNull(modelSettings.get(SemanticTextModelSettings.TASK_TYPE_FIELD.getPreferredName())); - assertNotNull(modelSettings.get(SemanticTextModelSettings.INFERENCE_ID_FIELD.getPreferredName())); - - List> inferenceResultElement = (List>) inferenceService1FieldResults.get( - INFERENCE_RESULTS - ); - assertFalse(inferenceResultElement.isEmpty()); - assertNotNull(inferenceResultElement.get(0).get(INFERENCE_CHUNKS_RESULTS)); - assertThat(inferenceResultElement.get(0).get(INFERENCE_CHUNKS_TEXT), equalTo(docSource.get(inferenceFieldName))); - } - } - - private static void verifyInferenceServiceInvoked( - ModelRegistry modelRegistry, - String inferenceService1Id, - InferenceService inferenceService, - Model model, - Collection inferenceTexts - ) { - verify(modelRegistry).getModelWithSecrets(eq(inferenceService1Id), any()); - verify(inferenceService).parsePersistedConfigWithSecrets( - eq(inferenceService1Id), - eq(TaskType.SPARSE_EMBEDDING), - anyMap(), - anyMap() - ); - verify(inferenceService).infer(eq(model), argThat(containsInAnyOrder(inferenceTexts)), anyMap(), eq(InputType.INGEST), any()); - verifyNoMoreInteractions(inferenceService); - } - - private static ArgumentMatcher> containsInAnyOrder(Collection expected) { - return new ArgumentMatcher<>() { - @Override - public boolean matches(List argument) { - return argument.containsAll(expected) && argument.size() == expected.size(); - } - - @Override - public String toString() { - return "containsAll(" + expected.stream().collect(Collectors.joining(", ")) + ")"; - } - }; - } - - private static BulkShardRequest runBulkOperation( - Map docSource, - FieldInferenceMetadata fieldInferenceMetadata, - ModelRegistry modelRegistry, - InferenceServiceRegistry inferenceServiceRegistry, - boolean expectTransportShardBulkActionToExecute, - ActionListener bulkOperationListener - ) { - return runBulkOperation( - docSource, - fieldInferenceMetadata, - modelRegistry, - inferenceServiceRegistry, - bulkOperationListener, - expectTransportShardBulkActionToExecute, - successfulBulkShardResponse - ); - } - - private static BulkShardRequest runBulkOperation( - Map docSource, - FieldInferenceMetadata fieldInferenceMetadata, - ModelRegistry modelRegistry, - InferenceServiceRegistry inferenceServiceRegistry, - ActionListener bulkOperationListener, - boolean expectTransportShardBulkActionToExecute, - Function bulkShardResponseSupplier - ) { - Settings settings = Settings.builder().put(IndexMetadata.SETTING_VERSION_CREATED, IndexVersion.current()).build(); - IndexMetadata indexMetadata = IndexMetadata.builder(INDEX_NAME) - .fieldInferenceMetadata(fieldInferenceMetadata) - .settings(settings) - .numberOfShards(1) - .numberOfReplicas(0) - .build(); - ClusterService clusterService = createClusterService(indexMetadata); - - IndexNameExpressionResolver indexResolver = mock(IndexNameExpressionResolver.class); - when(indexResolver.resolveWriteIndexAbstraction(any(), any())).thenReturn(new IndexAbstraction.ConcreteIndex(indexMetadata)); - - 
BulkRequest bulkRequest = new BulkRequest(); - bulkRequest.add(new IndexRequest(INDEX_NAME).source(docSource)); - - NodeClient client = mock(NodeClient.class); - - ArgumentCaptor bulkShardRequestCaptor = ArgumentCaptor.forClass(BulkShardRequest.class); - doAnswer(invocation -> { - BulkShardRequest request = invocation.getArgument(1); - ActionListener bulkShardResponseListener = invocation.getArgument(2); - bulkShardResponseListener.onResponse(bulkShardResponseSupplier.apply(request)); - return null; - }).when(client).executeLocally(eq(TransportShardBulkAction.TYPE), bulkShardRequestCaptor.capture(), any()); - - Task task = new Task(randomLong(), "transport", "action", "", null, emptyMap()); - BulkOperation bulkOperation = new BulkOperation( - task, - threadPool, - ThreadPool.Names.WRITE, - clusterService, - bulkRequest, - client, - new AtomicArray<>(bulkRequest.requests.size()), - new HashMap<>(), - indexResolver, - () -> System.nanoTime(), - System.nanoTime(), - modelRegistry, - inferenceServiceRegistry, - bulkOperationListener - ); - - bulkOperation.doRun(); - if (expectTransportShardBulkActionToExecute) { - verify(client).executeLocally(eq(TransportShardBulkAction.TYPE), any(), any()); - return bulkShardRequestCaptor.getValue(); - } - - return null; - } - - private static final Function successfulBulkShardResponse = (request) -> { - return new BulkShardResponse( - request.shardId(), - Arrays.stream(request.items()) - .filter(Objects::nonNull) - .map( - item -> BulkItemResponse.success( - item.id(), - DocWriteRequest.OpType.INDEX, - new IndexResponse(request.shardId(), randomIdentifier(), randomLong(), randomLong(), randomLong(), randomBoolean()) - ) - ) - .toArray(BulkItemResponse[]::new) - ); - }; - - private static InferenceService createInferenceService(Model model) { - InferenceService inferenceService = mock(InferenceService.class); - when( - inferenceService.parsePersistedConfigWithSecrets( - eq(model.getInferenceEntityId()), - eq(TaskType.SPARSE_EMBEDDING), - anyMap(), - anyMap() - ) - ).thenReturn(model); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(4); - InferenceServiceResults inferenceServiceResults = mock(InferenceServiceResults.class); - List texts = invocation.getArgument(1); - List inferenceResults = new ArrayList<>(); - for (int i = 0; i < texts.size(); i++) { - inferenceResults.add(createInferenceResults()); - } - doReturn(inferenceResults).when(inferenceServiceResults).transformToCoordinationFormat(); - - listener.onResponse(inferenceServiceResults); - return null; - }).when(inferenceService).infer(eq(model), anyList(), anyMap(), eq(InputType.INGEST), any()); - return inferenceService; - } - - private static InferenceService createInferenceServiceThatFails(Model model) { - InferenceService inferenceService = mock(InferenceService.class); - when( - inferenceService.parsePersistedConfigWithSecrets( - eq(model.getInferenceEntityId()), - eq(TaskType.SPARSE_EMBEDDING), - anyMap(), - anyMap() - ) - ).thenReturn(model); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(4); - listener.onFailure(new IllegalArgumentException(INFERENCE_FAILED_MSG)); - return null; - }).when(inferenceService).infer(eq(model), anyList(), anyMap(), eq(InputType.INGEST), any()); - return inferenceService; - } - - private static InferenceResults createInferenceResults() { - InferenceResults inferenceResults = mock(InferenceResults.class); - when(inferenceResults.asMap(any())).then( - invocation -> Map.of( - (String) 
invocation.getArguments()[0], - Map.of("sparse_embedding", randomMap(1, 10, () -> new Tuple<>(randomAlphaOfLength(10), randomFloat()))) - ) - ); - return inferenceResults; - } - - private static InferenceServiceRegistry createInferenceServiceRegistry(Map inferenceServices) { - InferenceServiceRegistry inferenceServiceRegistry = mock(InferenceServiceRegistry.class); - inferenceServices.forEach((id, service) -> when(inferenceServiceRegistry.getService(id)).thenReturn(Optional.of(service))); - return inferenceServiceRegistry; - } - - private static ModelRegistry createModelRegistry(Map inferenceIdsToServiceIds) { - ModelRegistry modelRegistry = mock(ModelRegistry.class); - // Fails for unknown inference ids - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onFailure(new IllegalArgumentException("Model not found")); - return null; - }).when(modelRegistry).getModelWithSecrets(any(), any()); - inferenceIdsToServiceIds.forEach((inferenceId, serviceId) -> { - ModelRegistry.UnparsedModel unparsedModel = new ModelRegistry.UnparsedModel( - inferenceId, - TaskType.SPARSE_EMBEDDING, - serviceId, - emptyMap(), - emptyMap() - ); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(unparsedModel); - return null; - }).when(modelRegistry).getModelWithSecrets(eq(inferenceId), any()); - }); - - return modelRegistry; - } - - private static ClusterService createClusterService(IndexMetadata indexMetadata) { - Metadata metadata = Metadata.builder().indices(Map.of(INDEX_NAME, indexMetadata)).build(); - - ClusterService clusterService = mock(ClusterService.class); - when(clusterService.localNode()).thenReturn(DiscoveryNodeUtils.create(randomIdentifier())); - - ClusterState clusterState = ClusterState.builder(ClusterName.DEFAULT).metadata(metadata).version(randomNonNegativeLong()).build(); - when(clusterService.state()).thenReturn(clusterState); - - ClusterApplierService clusterApplierService = mock(ClusterApplierService.class); - when(clusterApplierService.state()).thenReturn(clusterState); - when(clusterApplierService.threadPool()).thenReturn(threadPool); - when(clusterService.getClusterApplierService()).thenReturn(clusterApplierService); - return clusterService; - } - - @BeforeClass - public static void createThreadPool() { - threadPool = new TestThreadPool(getTestClass().getName()); - } - - @AfterClass - public static void stopThreadPool() { - if (threadPool != null) { - threadPool.shutdownNow(); - threadPool = null; - } - } - -} diff --git a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIndicesThatCannotBeCreatedTests.java b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIndicesThatCannotBeCreatedTests.java index 988a92352649a..3057b00553a22 100644 --- a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIndicesThatCannotBeCreatedTests.java +++ b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIndicesThatCannotBeCreatedTests.java @@ -129,19 +129,17 @@ public boolean hasIndexAbstraction(String indexAbstraction, ClusterState state) mock(ActionFilters.class), indexNameExpressionResolver, new IndexingPressure(Settings.EMPTY), - EmptySystemIndices.INSTANCE, - null, - null + EmptySystemIndices.INSTANCE ) { @Override void executeBulk( Task task, BulkRequest bulkRequest, long startTimeNanos, + ActionListener listener, String executorName, AtomicArray responses, - Map indicesThatCannotBeCreated, - ActionListener listener + Map 
indicesThatCannotBeCreated ) { assertEquals(expected, indicesThatCannotBeCreated.keySet()); } diff --git a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIngestTests.java b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIngestTests.java index 2d6492e4e73a4..6815d634292a4 100644 --- a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIngestTests.java +++ b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIngestTests.java @@ -148,9 +148,7 @@ class TestTransportBulkAction extends TransportBulkAction { new ActionFilters(Collections.emptySet()), TestIndexNameExpressionResolver.newInstance(), new IndexingPressure(SETTINGS), - EmptySystemIndices.INSTANCE, - null, - null + EmptySystemIndices.INSTANCE ); } @@ -159,10 +157,10 @@ void executeBulk( Task task, BulkRequest bulkRequest, long startTimeNanos, + ActionListener listener, String executorName, AtomicArray responses, - Map indicesThatCannotBeCreated, - ActionListener listener + Map indicesThatCannotBeCreated ) { assertTrue(indexCreated); isExecuted = true; diff --git a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTests.java b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTests.java index ad522e36f9bd9..1a16d9083df55 100644 --- a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTests.java +++ b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTests.java @@ -98,9 +98,7 @@ class TestTransportBulkAction extends TransportBulkAction { new ActionFilters(Collections.emptySet()), new Resolver(), new IndexingPressure(Settings.EMPTY), - EmptySystemIndices.INSTANCE, - null, - null + EmptySystemIndices.INSTANCE ); } diff --git a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTookTests.java b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTookTests.java index a2e54a1c7c3b8..cb9bdd1f3a827 100644 --- a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTookTests.java +++ b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTookTests.java @@ -139,13 +139,13 @@ void executeBulk( Task task, BulkRequest bulkRequest, long startTimeNanos, + ActionListener listener, String executorName, AtomicArray responses, - Map indicesThatCannotBeCreated, - ActionListener listener + Map indicesThatCannotBeCreated ) { expected.set(1000000); - super.executeBulk(task, bulkRequest, startTimeNanos, executorName, responses, indicesThatCannotBeCreated, listener); + super.executeBulk(task, bulkRequest, startTimeNanos, listener, executorName, responses, indicesThatCannotBeCreated); } }; } else { @@ -164,14 +164,14 @@ void executeBulk( Task task, BulkRequest bulkRequest, long startTimeNanos, + ActionListener listener, String executorName, AtomicArray responses, - Map indicesThatCannotBeCreated, - ActionListener listener + Map indicesThatCannotBeCreated ) { long elapsed = spinForAtLeastOneMillisecond(); expected.set(elapsed); - super.executeBulk(task, bulkRequest, startTimeNanos, executorName, responses, indicesThatCannotBeCreated, listener); + super.executeBulk(task, bulkRequest, startTimeNanos, listener, executorName, responses, indicesThatCannotBeCreated); } }; } @@ -253,9 +253,7 @@ static class TestTransportBulkAction extends TransportBulkAction { indexNameExpressionResolver, new IndexingPressure(Settings.EMPTY), EmptySystemIndices.INSTANCE, - relativeTimeProvider, - null, - null + relativeTimeProvider ); } } diff --git 
a/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java b/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java index 7f1b5cdaee598..0a53db94b9aaf 100644 --- a/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java +++ b/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java @@ -2360,9 +2360,7 @@ protected void assertSnapshotOrGenericThread() { actionFilters, indexNameExpressionResolver, new IndexingPressure(settings), - EmptySystemIndices.INSTANCE, - null, - null + EmptySystemIndices.INSTANCE ) ); final TransportShardBulkAction transportShardBulkAction = new TransportShardBulkAction( diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java index 33bbc94901e9d..b6e48d3b1c29a 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java @@ -123,15 +123,17 @@ private SparseEmbeddingResults makeResults(List input) { } private List makeChunkedResults(List input) { - var chunks = new ArrayList(); + List results = new ArrayList<>(); for (int i = 0; i < input.size(); i++) { var tokens = new ArrayList(); for (int j = 0; j < 5; j++) { tokens.add(new TextExpansionResults.WeightedToken("feature_" + j, j + 1.0F)); } - chunks.add(new ChunkedTextExpansionResults.ChunkedResult(input.get(i), tokens)); + results.add( + new ChunkedSparseEmbeddingResults(List.of(new ChunkedTextExpansionResults.ChunkedResult(input.get(i), tokens))) + ); } - return List.of(new ChunkedSparseEmbeddingResults(chunks)); + return results; } protected ServiceSettings getServiceSettingsFromMap(Map serviceSettingsMap) { diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryImplIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java similarity index 86% rename from x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryImplIT.java rename to x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java index ccda986a8d280..0f23e0b33d774 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryImplIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java @@ -26,7 +26,7 @@ import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.inference.InferencePlugin; -import org.elasticsearch.xpack.inference.registry.ModelRegistryImpl; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.elasticsearch.xpack.inference.services.elser.ElserInternalModel; import org.elasticsearch.xpack.inference.services.elser.ElserInternalService; import org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettingsTests; @@ -55,13 +55,13 @@ import static org.hamcrest.Matchers.nullValue; import 
static org.mockito.Mockito.mock;
 
-public class ModelRegistryImplIT extends ESSingleNodeTestCase {
+public class ModelRegistryIT extends ESSingleNodeTestCase {
 
-    private ModelRegistryImpl ModelRegistryImpl;
+    private ModelRegistry modelRegistry;
 
     @Before
     public void createComponents() {
-        ModelRegistryImpl = new ModelRegistryImpl(client());
+        modelRegistry = new ModelRegistry(client());
     }
 
     @Override
@@ -75,7 +75,7 @@ public void testStoreModel() throws Exception {
         AtomicReference<Boolean> storeModelHolder = new AtomicReference<>();
         AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
 
-        blockingCall(listener -> ModelRegistryImpl.storeModel(model, listener), storeModelHolder, exceptionHolder);
+        blockingCall(listener -> modelRegistry.storeModel(model, listener), storeModelHolder, exceptionHolder);
 
         assertThat(storeModelHolder.get(), is(true));
         assertThat(exceptionHolder.get(), is(nullValue()));
@@ -87,7 +87,7 @@ public void testStoreModelWithUnknownFields() throws Exception {
         AtomicReference<Boolean> storeModelHolder = new AtomicReference<>();
         AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
 
-        blockingCall(listener -> ModelRegistryImpl.storeModel(model, listener), storeModelHolder, exceptionHolder);
+        blockingCall(listener -> modelRegistry.storeModel(model, listener), storeModelHolder, exceptionHolder);
 
         assertNull(storeModelHolder.get());
         assertNotNull(exceptionHolder.get());
@@ -106,12 +106,12 @@ public void testGetModel() throws Exception {
         AtomicReference<Boolean> putModelHolder = new AtomicReference<>();
         AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
 
-        blockingCall(listener -> ModelRegistryImpl.storeModel(model, listener), putModelHolder, exceptionHolder);
+        blockingCall(listener -> modelRegistry.storeModel(model, listener), putModelHolder, exceptionHolder);
         assertThat(putModelHolder.get(), is(true));
 
         // now get the model
-        AtomicReference modelHolder = new AtomicReference<>();
-        blockingCall(listener -> ModelRegistryImpl.getModelWithSecrets(inferenceEntityId, listener), modelHolder, exceptionHolder);
+        AtomicReference modelHolder = new AtomicReference<>();
+        blockingCall(listener -> modelRegistry.getModelWithSecrets(inferenceEntityId, listener), modelHolder, exceptionHolder);
         assertThat(exceptionHolder.get(), is(nullValue()));
         assertThat(modelHolder.get(), not(nullValue()));
@@ -133,13 +133,13 @@ public void testStoreModelFailsWhenModelExists() throws Exception {
         AtomicReference<Boolean> putModelHolder = new AtomicReference<>();
         AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
 
-        blockingCall(listener -> ModelRegistryImpl.storeModel(model, listener), putModelHolder, exceptionHolder);
+        blockingCall(listener -> modelRegistry.storeModel(model, listener), putModelHolder, exceptionHolder);
         assertThat(putModelHolder.get(), is(true));
         assertThat(exceptionHolder.get(), is(nullValue()));
 
         putModelHolder.set(false);
         // a model with the same id exists
-        blockingCall(listener -> ModelRegistryImpl.storeModel(model, listener), putModelHolder, exceptionHolder);
+        blockingCall(listener -> modelRegistry.storeModel(model, listener), putModelHolder, exceptionHolder);
         assertThat(putModelHolder.get(), is(false));
         assertThat(exceptionHolder.get(), not(nullValue()));
         assertThat(
@@ -154,20 +154,20 @@ public void testDeleteModel() throws Exception {
             Model model = buildElserModelConfig(id, TaskType.SPARSE_EMBEDDING);
             AtomicReference<Boolean> putModelHolder = new AtomicReference<>();
             AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
-            blockingCall(listener -> ModelRegistryImpl.storeModel(model, listener),
putModelHolder, exceptionHolder); + blockingCall(listener -> modelRegistry.storeModel(model, listener), putModelHolder, exceptionHolder); assertThat(putModelHolder.get(), is(true)); } AtomicReference deleteResponseHolder = new AtomicReference<>(); AtomicReference exceptionHolder = new AtomicReference<>(); - blockingCall(listener -> ModelRegistryImpl.deleteModel("model1", listener), deleteResponseHolder, exceptionHolder); + blockingCall(listener -> modelRegistry.deleteModel("model1", listener), deleteResponseHolder, exceptionHolder); assertThat(exceptionHolder.get(), is(nullValue())); assertTrue(deleteResponseHolder.get()); // get should fail deleteResponseHolder.set(false); - AtomicReference modelHolder = new AtomicReference<>(); - blockingCall(listener -> ModelRegistryImpl.getModelWithSecrets("model1", listener), modelHolder, exceptionHolder); + AtomicReference modelHolder = new AtomicReference<>(); + blockingCall(listener -> modelRegistry.getModelWithSecrets("model1", listener), modelHolder, exceptionHolder); assertThat(exceptionHolder.get(), not(nullValue())); assertFalse(deleteResponseHolder.get()); @@ -187,13 +187,13 @@ public void testGetModelsByTaskType() throws InterruptedException { AtomicReference putModelHolder = new AtomicReference<>(); AtomicReference exceptionHolder = new AtomicReference<>(); - blockingCall(listener -> ModelRegistryImpl.storeModel(model, listener), putModelHolder, exceptionHolder); + blockingCall(listener -> modelRegistry.storeModel(model, listener), putModelHolder, exceptionHolder); assertThat(putModelHolder.get(), is(true)); } AtomicReference exceptionHolder = new AtomicReference<>(); - AtomicReference> modelHolder = new AtomicReference<>(); - blockingCall(listener -> ModelRegistryImpl.getModelsByTaskType(TaskType.SPARSE_EMBEDDING, listener), modelHolder, exceptionHolder); + AtomicReference> modelHolder = new AtomicReference<>(); + blockingCall(listener -> modelRegistry.getModelsByTaskType(TaskType.SPARSE_EMBEDDING, listener), modelHolder, exceptionHolder); assertThat(modelHolder.get(), hasSize(3)); var sparseIds = sparseAndTextEmbeddingModels.stream() .filter(m -> m.getConfigurations().getTaskType() == TaskType.SPARSE_EMBEDDING) @@ -204,7 +204,7 @@ public void testGetModelsByTaskType() throws InterruptedException { assertThat(m.secrets().keySet(), empty()); }); - blockingCall(listener -> ModelRegistryImpl.getModelsByTaskType(TaskType.TEXT_EMBEDDING, listener), modelHolder, exceptionHolder); + blockingCall(listener -> modelRegistry.getModelsByTaskType(TaskType.TEXT_EMBEDDING, listener), modelHolder, exceptionHolder); assertThat(modelHolder.get(), hasSize(2)); var denseIds = sparseAndTextEmbeddingModels.stream() .filter(m -> m.getConfigurations().getTaskType() == TaskType.TEXT_EMBEDDING) @@ -228,13 +228,13 @@ public void testGetAllModels() throws InterruptedException { var model = createModel(randomAlphaOfLength(5), randomFrom(TaskType.values()), service); createdModels.add(model); - blockingCall(listener -> ModelRegistryImpl.storeModel(model, listener), putModelHolder, exceptionHolder); + blockingCall(listener -> modelRegistry.storeModel(model, listener), putModelHolder, exceptionHolder); assertThat(putModelHolder.get(), is(true)); assertNull(exceptionHolder.get()); } - AtomicReference> modelHolder = new AtomicReference<>(); - blockingCall(listener -> ModelRegistryImpl.getAllModels(listener), modelHolder, exceptionHolder); + AtomicReference> modelHolder = new AtomicReference<>(); + blockingCall(listener -> modelRegistry.getAllModels(listener), modelHolder, 
exceptionHolder); assertThat(modelHolder.get(), hasSize(modelCount)); var getAllModels = modelHolder.get(); @@ -258,18 +258,18 @@ public void testGetModelWithSecrets() throws InterruptedException { AtomicReference exceptionHolder = new AtomicReference<>(); var modelWithSecrets = createModelWithSecrets(inferenceEntityId, randomFrom(TaskType.values()), service, secret); - blockingCall(listener -> ModelRegistryImpl.storeModel(modelWithSecrets, listener), putModelHolder, exceptionHolder); + blockingCall(listener -> modelRegistry.storeModel(modelWithSecrets, listener), putModelHolder, exceptionHolder); assertThat(putModelHolder.get(), is(true)); assertNull(exceptionHolder.get()); - AtomicReference modelHolder = new AtomicReference<>(); - blockingCall(listener -> ModelRegistryImpl.getModelWithSecrets(inferenceEntityId, listener), modelHolder, exceptionHolder); + AtomicReference modelHolder = new AtomicReference<>(); + blockingCall(listener -> modelRegistry.getModelWithSecrets(inferenceEntityId, listener), modelHolder, exceptionHolder); assertThat(modelHolder.get().secrets().keySet(), hasSize(1)); var secretSettings = (Map) modelHolder.get().secrets().get("secret_settings"); assertThat(secretSettings.get("secret"), equalTo(secret)); // get model without secrets - blockingCall(listener -> ModelRegistryImpl.getModel(inferenceEntityId, listener), modelHolder, exceptionHolder); + blockingCall(listener -> modelRegistry.getModel(inferenceEntityId, listener), modelHolder, exceptionHolder); assertThat(modelHolder.get().secrets().keySet(), empty()); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 31aae67770c98..2a9c300e12c13 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -10,6 +10,7 @@ import org.apache.lucene.util.SetOnce; import org.elasticsearch.action.ActionRequest; import org.elasticsearch.action.ActionResponse; +import org.elasticsearch.action.support.ActionFilter; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; import org.elasticsearch.cluster.node.DiscoveryNodes; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; @@ -26,11 +27,8 @@ import org.elasticsearch.indices.SystemIndexDescriptor; import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceRegistry; -import org.elasticsearch.inference.InferenceServiceRegistryImpl; -import org.elasticsearch.inference.ModelRegistry; import org.elasticsearch.plugins.ActionPlugin; import org.elasticsearch.plugins.ExtensiblePlugin; -import org.elasticsearch.plugins.InferenceRegistryPlugin; import org.elasticsearch.plugins.MapperPlugin; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.plugins.SystemIndexPlugin; @@ -49,6 +47,7 @@ import org.elasticsearch.xpack.inference.action.TransportInferenceAction; import org.elasticsearch.xpack.inference.action.TransportInferenceUsageAction; import org.elasticsearch.xpack.inference.action.TransportPutInferenceModelAction; +import org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter; import org.elasticsearch.xpack.inference.common.Truncator; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import 
org.elasticsearch.xpack.inference.external.http.HttpSettings; @@ -56,9 +55,9 @@ import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.RequestExecutorServiceSettings; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapper; import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; -import org.elasticsearch.xpack.inference.mapper.SemanticTextInferenceResultFieldMapper; -import org.elasticsearch.xpack.inference.registry.ModelRegistryImpl; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.elasticsearch.xpack.inference.rest.RestDeleteInferenceModelAction; import org.elasticsearch.xpack.inference.rest.RestGetInferenceModelAction; import org.elasticsearch.xpack.inference.rest.RestInferenceAction; @@ -80,13 +79,9 @@ import java.util.stream.Collectors; import java.util.stream.Stream; -public class InferencePlugin extends Plugin - implements - ActionPlugin, - ExtensiblePlugin, - SystemIndexPlugin, - InferenceRegistryPlugin, - MapperPlugin { +import static java.util.Collections.singletonList; + +public class InferencePlugin extends Plugin implements ActionPlugin, ExtensiblePlugin, SystemIndexPlugin, MapperPlugin { /** * When this setting is true the verification check that @@ -111,8 +106,7 @@ public class InferencePlugin extends Plugin private final SetOnce serviceComponents = new SetOnce<>(); private final SetOnce inferenceServiceRegistry = new SetOnce<>(); - private final SetOnce modelRegistry = new SetOnce<>(); - + private final SetOnce shardBulkInferenceActionFilter = new SetOnce<>(); private List inferenceServiceExtensions; public InferencePlugin(Settings settings) { @@ -163,7 +157,7 @@ public Collection createComponents(PluginServices services) { ); httpFactory.set(httpRequestSenderFactory); - ModelRegistry modelReg = new ModelRegistryImpl(services.client()); + ModelRegistry modelRegistry = new ModelRegistry(services.client()); if (inferenceServiceExtensions == null) { inferenceServiceExtensions = new ArrayList<>(); @@ -174,13 +168,14 @@ public Collection createComponents(PluginServices services) { var factoryContext = new InferenceServiceExtension.InferenceServiceFactoryContext(services.client()); // This must be done after the HttpRequestSenderFactory is created so that the services can get the // reference correctly - var inferenceRegistry = new InferenceServiceRegistryImpl(inferenceServices, factoryContext); - inferenceRegistry.init(services.client()); - inferenceServiceRegistry.set(inferenceRegistry); - modelRegistry.set(modelReg); + var registry = new InferenceServiceRegistry(inferenceServices, factoryContext); + registry.init(services.client()); + inferenceServiceRegistry.set(registry); + + var actionFilter = new ShardBulkInferenceActionFilter(registry, modelRegistry); + shardBulkInferenceActionFilter.set(actionFilter); - // Don't return components as they will be registered using InferenceRegistryPlugin methods to retrieve them - return List.of(); + return List.of(modelRegistry, registry); } @Override @@ -279,16 +274,6 @@ public void close() { IOUtils.closeWhileHandlingException(inferenceServiceRegistry.get(), throttlerToClose); } - @Override - public InferenceServiceRegistry getInferenceServiceRegistry() { - return inferenceServiceRegistry.get(); - } - - @Override - public ModelRegistry getModelRegistry() { - return modelRegistry.get(); - } - @Override public Map 
getMappers() { if (SemanticTextFeature.isEnabled()) { @@ -299,6 +284,11 @@ public Map getMappers() { @Override public Map getMetadataMappers() { - return Map.of(SemanticTextInferenceResultFieldMapper.NAME, SemanticTextInferenceResultFieldMapper.PARSER); + return Map.of(InferenceResultFieldMapper.NAME, InferenceResultFieldMapper.PARSER); + } + + @Override + public Collection getActionFilters() { + return singletonList(shardBulkInferenceActionFilter.get()); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceModelAction.java index ad6042581f264..b55e2e6f8ebed 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceModelAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceModelAction.java @@ -23,12 +23,12 @@ import org.elasticsearch.common.inject.Inject; import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.inference.InferenceServiceRegistry; -import org.elasticsearch.inference.ModelRegistry; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.inference.action.DeleteInferenceModelAction; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; public class TransportDeleteInferenceModelAction extends AcknowledgedTransportMasterNodeAction { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceModelAction.java index 0f7e48c4f8140..2de1aecea118c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceModelAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceModelAction.java @@ -17,7 +17,6 @@ import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.inference.InferenceServiceRegistry; import org.elasticsearch.inference.ModelConfigurations; -import org.elasticsearch.inference.ModelRegistry; import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; @@ -25,6 +24,7 @@ import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction; import org.elasticsearch.xpack.inference.InferencePlugin; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; import java.util.ArrayList; import java.util.List; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java index ece4fee1c935f..fb3974fc12e8b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java @@ -16,11 +16,11 @@ import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceRegistry; import 
org.elasticsearch.inference.Model; -import org.elasticsearch.inference.ModelRegistry; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; public class TransportInferenceAction extends HandledTransportAction { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java index 6667e314a62b8..07d28f8e5b0a8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java @@ -29,7 +29,6 @@ import org.elasticsearch.inference.InferenceServiceRegistry; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; -import org.elasticsearch.inference.ModelRegistry; import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; @@ -44,6 +43,7 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.utils.MlPlatformArchitecturesUtil; import org.elasticsearch.xpack.inference.InferencePlugin; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; import java.io.IOException; import java.util.Map; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java new file mode 100644 index 0000000000000..fbf84762eb314 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -0,0 +1,343 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */
+
+package org.elasticsearch.xpack.inference.action.filter;
+
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.elasticsearch.ElasticsearchException;
+import org.elasticsearch.ElasticsearchStatusException;
+import org.elasticsearch.ResourceNotFoundException;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.ActionRequest;
+import org.elasticsearch.action.ActionResponse;
+import org.elasticsearch.action.DocWriteRequest;
+import org.elasticsearch.action.bulk.BulkItemRequest;
+import org.elasticsearch.action.bulk.BulkShardRequest;
+import org.elasticsearch.action.bulk.TransportShardBulkAction;
+import org.elasticsearch.action.index.IndexRequest;
+import org.elasticsearch.action.support.ActionFilter;
+import org.elasticsearch.action.support.ActionFilterChain;
+import org.elasticsearch.action.support.MappedActionFilter;
+import org.elasticsearch.action.support.RefCountingRunnable;
+import org.elasticsearch.action.update.UpdateRequest;
+import org.elasticsearch.cluster.metadata.FieldInferenceMetadata;
+import org.elasticsearch.common.util.concurrent.AtomicArray;
+import org.elasticsearch.common.xcontent.support.XContentMapValues;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.core.Releasable;
+import org.elasticsearch.inference.ChunkedInferenceServiceResults;
+import org.elasticsearch.inference.ChunkingOptions;
+import org.elasticsearch.inference.InferenceService;
+import org.elasticsearch.inference.InferenceServiceRegistry;
+import org.elasticsearch.inference.InputType;
+import org.elasticsearch.inference.Model;
+import org.elasticsearch.rest.RestStatus;
+import org.elasticsearch.tasks.Task;
+import org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapper;
+import org.elasticsearch.xpack.inference.registry.ModelRegistry;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+
+/**
+ * An {@link ActionFilter} that performs inference on {@link BulkShardRequest} asynchronously and stores the results in
+ * the individual {@link BulkItemRequest}. The results are then consumed by the {@link InferenceResultFieldMapper}
+ * in the downstream {@link TransportShardBulkAction}.
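+ *
+ * <p>A rough sketch of the flow, for illustration only (the task, request, listener and chain are supplied by the
+ * action filter infrastructure; no new API is assumed):
+ * <pre>{@code
+ * // the filter only reacts to the shard-level bulk action
+ * filter.apply(task, TransportShardBulkAction.ACTION_NAME, bulkShardRequest, listener, chain);
+ * // when the request carries field inference metadata, inference runs asynchronously and
+ * // chain.proceed(...) is invoked only after every field inference request has completed
+ * }</pre>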
+ */ +public class ShardBulkInferenceActionFilter implements MappedActionFilter { + private static final Logger logger = LogManager.getLogger(ShardBulkInferenceActionFilter.class); + + private final InferenceServiceRegistry inferenceServiceRegistry; + private final ModelRegistry modelRegistry; + + public ShardBulkInferenceActionFilter(InferenceServiceRegistry inferenceServiceRegistry, ModelRegistry modelRegistry) { + this.inferenceServiceRegistry = inferenceServiceRegistry; + this.modelRegistry = modelRegistry; + } + + @Override + public int order() { + // must execute last (after the security action filter) + return Integer.MAX_VALUE; + } + + @Override + public String actionName() { + return TransportShardBulkAction.ACTION_NAME; + } + + @Override + public void apply( + Task task, + String action, + Request request, + ActionListener listener, + ActionFilterChain chain + ) { + switch (action) { + case TransportShardBulkAction.ACTION_NAME: + BulkShardRequest bulkShardRequest = (BulkShardRequest) request; + var fieldInferenceMetadata = bulkShardRequest.consumeFieldInferenceMetadata(); + if (fieldInferenceMetadata != null && fieldInferenceMetadata.isEmpty() == false) { + Runnable onInferenceCompletion = () -> chain.proceed(task, action, request, listener); + processBulkShardRequest(fieldInferenceMetadata, bulkShardRequest, onInferenceCompletion); + } else { + chain.proceed(task, action, request, listener); + } + break; + + default: + chain.proceed(task, action, request, listener); + break; + } + } + + private void processBulkShardRequest( + FieldInferenceMetadata fieldInferenceMetadata, + BulkShardRequest bulkShardRequest, + Runnable onCompletion + ) { + new AsyncBulkShardInferenceAction(fieldInferenceMetadata, bulkShardRequest, onCompletion).run(); + } + + private record InferenceProvider(InferenceService service, Model model) {} + + private record FieldInferenceRequest(int id, String field, String input) {} + + private record FieldInferenceResponse(String field, Model model, ChunkedInferenceServiceResults chunkedResults) {} + + private record FieldInferenceResponseAccumulator(int id, List responses, List failures) {} + + private class AsyncBulkShardInferenceAction implements Runnable { + private final FieldInferenceMetadata fieldInferenceMetadata; + private final BulkShardRequest bulkShardRequest; + private final Runnable onCompletion; + private final AtomicArray inferenceResults; + + private AsyncBulkShardInferenceAction( + FieldInferenceMetadata fieldInferenceMetadata, + BulkShardRequest bulkShardRequest, + Runnable onCompletion + ) { + this.fieldInferenceMetadata = fieldInferenceMetadata; + this.bulkShardRequest = bulkShardRequest; + this.inferenceResults = new AtomicArray<>(bulkShardRequest.items().length); + this.onCompletion = onCompletion; + } + + @Override + public void run() { + Map> inferenceRequests = createFieldInferenceRequests(bulkShardRequest); + Runnable onInferenceCompletion = () -> { + try { + for (var inferenceResponse : inferenceResults.asList()) { + var request = bulkShardRequest.items()[inferenceResponse.id]; + try { + applyInferenceResponses(request, inferenceResponse); + } catch (Exception exc) { + request.abort(bulkShardRequest.index(), exc); + } + } + } finally { + onCompletion.run(); + } + }; + try (var releaseOnFinish = new RefCountingRunnable(onInferenceCompletion)) { + for (var entry : inferenceRequests.entrySet()) { + executeShardBulkInferenceAsync(entry.getKey(), null, entry.getValue(), releaseOnFinish.acquire()); + } + } + } + + private void 
executeShardBulkInferenceAsync(
+            final String inferenceId,
+            @Nullable InferenceProvider inferenceProvider,
+            final List<FieldInferenceRequest> requests,
+            final Releasable onFinish
+        ) {
+            if (inferenceProvider == null) {
+                ActionListener<ModelRegistry.UnparsedModel> modelLoadingListener = new ActionListener<>() {
+                    @Override
+                    public void onResponse(ModelRegistry.UnparsedModel unparsedModel) {
+                        var service = inferenceServiceRegistry.getService(unparsedModel.service());
+                        if (service.isEmpty() == false) {
+                            var provider = new InferenceProvider(
+                                service.get(),
+                                service.get()
+                                    .parsePersistedConfigWithSecrets(
+                                        inferenceId,
+                                        unparsedModel.taskType(),
+                                        unparsedModel.settings(),
+                                        unparsedModel.secrets()
+                                    )
+                            );
+                            executeShardBulkInferenceAsync(inferenceId, provider, requests, onFinish);
+                        } else {
+                            try (onFinish) {
+                                for (int i = 0; i < requests.size(); i++) {
+                                    var request = requests.get(i);
+                                    inferenceResults.get(request.id).failures.add(
+                                        new ResourceNotFoundException(
+                                            "Inference id [{}] not found for field [{}]",
+                                            inferenceId,
+                                            request.field
+                                        )
+                                    );
+                                }
+                            }
+                        }
+                    }
+
+                    @Override
+                    public void onFailure(Exception exc) {
+                        try (onFinish) {
+                            for (int i = 0; i < requests.size(); i++) {
+                                var request = requests.get(i);
+                                inferenceResults.get(request.id).failures.add(
+                                    new ResourceNotFoundException("Inference id [{}] not found for field [{}]", inferenceId, request.field)
+                                );
+                            }
+                        }
+                    }
+                };
+                modelRegistry.getModelWithSecrets(inferenceId, modelLoadingListener);
+                return;
+            }
+            final List<String> inputs = requests.stream().map(FieldInferenceRequest::input).collect(Collectors.toList());
+            ActionListener<List<ChunkedInferenceServiceResults>> completionListener = new ActionListener<>() {
+                @Override
+                public void onResponse(List<ChunkedInferenceServiceResults> results) {
+                    for (int i = 0; i < results.size(); i++) {
+                        var request = requests.get(i);
+                        var result = results.get(i);
+                        var acc = inferenceResults.get(request.id);
+                        acc.responses.add(new FieldInferenceResponse(request.field, inferenceProvider.model, result));
+                    }
+                }
+
+                @Override
+                public void onFailure(Exception exc) {
+                    for (int i = 0; i < requests.size(); i++) {
+                        var request = requests.get(i);
+                        inferenceResults.get(request.id).failures.add(
+                            new ElasticsearchException(
+                                "Exception when running inference id [{}] on field [{}]",
+                                exc,
+                                inferenceProvider.model.getInferenceEntityId(),
+                                request.field
+                            )
+                        );
+                    }
+                }
+            };
+            inferenceProvider.service()
+                .chunkedInfer(
+                    inferenceProvider.model(),
+                    inputs,
+                    Map.of(),
+                    InputType.INGEST,
+                    new ChunkingOptions(null, null),
+                    ActionListener.runAfter(completionListener, onFinish::close)
+                );
+        }
+
+        /**
+         * Applies the {@link FieldInferenceResponseAccumulator} to the provided {@link BulkItemRequest}.
+         * If the response contains failures, the bulk item request is marked as failed for the downstream action.
+         * Otherwise, the source of the request is augmented with the field inference results.
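+         *
+         * <p>For example (field name, text and results are hypothetical), a source of
+         * {@code {"my_field": "some text"}} is augmented to
+         * {@code {"my_field": "some text", "_inference": {"my_field": {<model settings>, "results": [<chunks>]}}}}.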
+ */ + private void applyInferenceResponses(BulkItemRequest item, FieldInferenceResponseAccumulator response) { + if (response.failures().isEmpty() == false) { + for (var failure : response.failures()) { + item.abort(item.index(), failure); + } + return; + } + + final IndexRequest indexRequest = getIndexRequestOrNull(item.request()); + Map newDocMap = indexRequest.sourceAsMap(); + Map inferenceMap = new LinkedHashMap<>(); + // ignore the existing inference map if any + newDocMap.put(InferenceResultFieldMapper.NAME, inferenceMap); + for (FieldInferenceResponse fieldResponse : response.responses()) { + try { + InferenceResultFieldMapper.applyFieldInference( + inferenceMap, + fieldResponse.field(), + fieldResponse.model(), + fieldResponse.chunkedResults() + ); + } catch (Exception exc) { + item.abort(item.index(), exc); + } + } + indexRequest.source(newDocMap); + } + + private Map> createFieldInferenceRequests(BulkShardRequest bulkShardRequest) { + Map> fieldRequestsMap = new LinkedHashMap<>(); + for (var item : bulkShardRequest.items()) { + if (item.getPrimaryResponse() != null) { + // item was already aborted/processed by a filter in the chain upstream (e.g. security) + continue; + } + final IndexRequest indexRequest = getIndexRequestOrNull(item.request()); + if (indexRequest == null) { + continue; + } + final Map docMap = indexRequest.sourceAsMap(); + for (var entry : fieldInferenceMetadata.getFieldInferenceOptions().entrySet()) { + String field = entry.getKey(); + String inferenceId = entry.getValue().inferenceId(); + var value = XContentMapValues.extractValue(field, docMap); + if (value == null) { + continue; + } + if (inferenceResults.get(item.id()) == null) { + inferenceResults.set( + item.id(), + new FieldInferenceResponseAccumulator( + item.id(), + Collections.synchronizedList(new ArrayList<>()), + Collections.synchronizedList(new ArrayList<>()) + ) + ); + } + if (value instanceof String valueStr) { + List fieldRequests = fieldRequestsMap.computeIfAbsent(inferenceId, k -> new ArrayList<>()); + fieldRequests.add(new FieldInferenceRequest(item.id(), field, valueStr)); + } else { + inferenceResults.get(item.id()).failures.add( + new ElasticsearchStatusException( + "Invalid format for field [{}], expected [String] got [{}]", + RestStatus.BAD_REQUEST, + field, + value.getClass().getSimpleName() + ) + ); + } + } + } + return fieldRequestsMap; + } + } + + static IndexRequest getIndexRequestOrNull(DocWriteRequest docWriteRequest) { + if (docWriteRequest instanceof IndexRequest indexRequest) { + return indexRequest; + } else if (docWriteRequest instanceof UpdateRequest updateRequest) { + return updateRequest.doc(); + } else { + return null; + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapper.java similarity index 84% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapper.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapper.java index ad1e0f8c8cb81..2ede5419ab74e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapper.java @@ -8,7 +8,8 @@ package 
org.elasticsearch.xpack.inference.mapper; import org.apache.lucene.search.Query; -import org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.common.Strings; import org.elasticsearch.index.IndexVersion; import org.elasticsearch.index.mapper.DocumentParserContext; @@ -28,23 +29,27 @@ import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper; import org.elasticsearch.index.query.SearchExecutionContext; -import org.elasticsearch.inference.SemanticTextModelSettings; +import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.Model; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.logging.LogManager; import org.elasticsearch.logging.Logger; +import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xpack.core.inference.results.ChunkedSparseEmbeddingResults; +import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingResults; import java.io.IOException; +import java.util.ArrayList; import java.util.Collections; import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; import java.util.Set; import java.util.stream.Collectors; -import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.INFERENCE_CHUNKS_RESULTS; -import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.INFERENCE_CHUNKS_TEXT; -import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.ROOT_INFERENCE_FIELD; - /** * A mapper for the {@code _semantic_text_inference} field. *
@@ -57,7 +62,7 @@ * { * "_source": { * "my_semantic_text_field": "these are not the droids you're looking for", - * "_semantic_text_inference": { + * "_inference": { * "my_semantic_text_field": [ * { * "sparse_embedding": { @@ -100,12 +105,17 @@ * } * */ -public class SemanticTextInferenceResultFieldMapper extends MetadataFieldMapper { - public static final String CONTENT_TYPE = "_semantic_text_inference"; - public static final String NAME = ROOT_INFERENCE_FIELD; - public static final TypeParser PARSER = new FixedTypeParser(c -> new SemanticTextInferenceResultFieldMapper()); +public class InferenceResultFieldMapper extends MetadataFieldMapper { + public static final String NAME = "_inference"; + public static final String CONTENT_TYPE = "_inference"; + + public static final String RESULTS = "results"; + public static final String INFERENCE_CHUNKS_RESULTS = "inference"; + public static final String INFERENCE_CHUNKS_TEXT = "text"; + + public static final TypeParser PARSER = new FixedTypeParser(c -> new InferenceResultFieldMapper()); - private static final Logger logger = LogManager.getLogger(SemanticTextInferenceResultFieldMapper.class); + private static final Logger logger = LogManager.getLogger(InferenceResultFieldMapper.class); private static final Set REQUIRED_SUBFIELDS = Set.of(INFERENCE_CHUNKS_TEXT, INFERENCE_CHUNKS_RESULTS); @@ -132,7 +142,7 @@ public Query termQuery(Object value, SearchExecutionContext context) { } } - private SemanticTextInferenceResultFieldMapper() { + public InferenceResultFieldMapper() { super(SemanticTextInferenceFieldType.INSTANCE); } @@ -173,7 +183,7 @@ private static void parseSingleField(DocumentParserContext context, MapperBuilde failIfTokenIsNot(parser, XContentParser.Token.FIELD_NAME); String currentName = parser.currentName(); - if (BulkShardRequestInferenceProvider.INFERENCE_RESULTS.equals(currentName)) { + if (RESULTS.equals(currentName)) { NestedObjectMapper nestedObjectMapper = createInferenceResultsObjectMapper( context, mapperBuilderContext, @@ -329,4 +339,34 @@ protected String contentType() { public SourceLoader.SyntheticFieldLoader syntheticFieldLoader() { return SourceLoader.SyntheticFieldLoader.NOTHING; } + + public static void applyFieldInference( + Map inferenceMap, + String field, + Model model, + ChunkedInferenceServiceResults results + ) throws ElasticsearchException { + List> chunks = new ArrayList<>(); + if (results instanceof ChunkedSparseEmbeddingResults textExpansionResults) { + for (var chunk : textExpansionResults.getChunkedResults()) { + chunks.add(chunk.asMap()); + } + } else if (results instanceof ChunkedTextEmbeddingResults textEmbeddingResults) { + for (var chunk : textEmbeddingResults.getChunks()) { + chunks.add(chunk.asMap()); + } + } else { + throw new ElasticsearchStatusException( + "Invalid inference results format for field [{}] with inference id [{}], got {}", + RestStatus.BAD_REQUEST, + field, + model.getInferenceEntityId(), + results.getWriteableName() + ); + } + Map fieldMap = new LinkedHashMap<>(); + fieldMap.putAll(new SemanticTextModelSettings(model).asMap()); + fieldMap.put(InferenceResultFieldMapper.RESULTS, chunks); + inferenceMap.put(field, fieldMap); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java index d9e18728615ba..83272a10f98d4 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java @@ -30,7 +30,7 @@ * at ingestion and query time. * For now, it is compatible with text expansion models only, but will be extended to support dense vector models as well. * This field mapper performs no indexing, as inference results will be included as a different field in the document source, and will - * be indexed using {@link SemanticTextInferenceResultFieldMapper}. + * be indexed using {@link InferenceResultFieldMapper}. */ public class SemanticTextFieldMapper extends FieldMapper { diff --git a/server/src/main/java/org/elasticsearch/inference/SemanticTextModelSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextModelSettings.java similarity index 92% rename from server/src/main/java/org/elasticsearch/inference/SemanticTextModelSettings.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextModelSettings.java index 3561c2351427c..1b6bb22c0d6b5 100644 --- a/server/src/main/java/org/elasticsearch/inference/SemanticTextModelSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextModelSettings.java @@ -1,13 +1,15 @@ /* * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one * or more contributor license agreements. Licensed under the Elastic License - * 2.0 and the Server Side Public License, v 1; you may not use this file except - * in compliance with, at your election, the Elastic License 2.0 or the Server - * Side Public License, v 1. + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. */ -package org.elasticsearch.inference; +package org.elasticsearch.xpack.inference.mapper; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentParser; @@ -19,7 +21,6 @@ /** * Serialization class for specifying the settings of a model from semantic_text inference to field mapper. 
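+ * For example, a sparse embedding model might serialize as
+ * {@code {"task_type": "sparse_embedding", "inference_id": "my-inference-id"}} (values are illustrative).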
- * See {@link org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider}
  */
 public class SemanticTextModelSettings {
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistryImpl.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java
similarity index 86%
rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistryImpl.java
rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java
index 40921cd38f181..0f3aa5b82b189 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistryImpl.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java
@@ -24,7 +24,6 @@
 import org.elasticsearch.action.support.WriteRequest;
 import org.elasticsearch.client.internal.Client;
 import org.elasticsearch.client.internal.OriginSettingClient;
-import org.elasticsearch.common.inject.Inject;
 import org.elasticsearch.index.engine.VersionConflictEngineException;
 import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.index.query.QueryBuilders;
@@ -32,7 +31,6 @@
 import org.elasticsearch.index.reindex.DeleteByQueryRequest;
 import org.elasticsearch.inference.Model;
 import org.elasticsearch.inference.ModelConfigurations;
-import org.elasticsearch.inference.ModelRegistry;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.search.SearchHit;
@@ -57,21 +55,49 @@ import static org.elasticsearch.core.Strings.format;
 
-public class ModelRegistryImpl implements ModelRegistry {
+public class ModelRegistry {
 
     public record ModelConfigMap(Map<String, Object> config, Map<String, Object> secrets) {}
 
+    /**
+     * A semi-parsed model: the inference entity id, task type and service
+     * are known, but the settings are not parsed.
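+     *
+     * <p>A minimal usage sketch (the config map values are illustrative):
+     * <pre>{@code
+     * UnparsedModel unparsed = UnparsedModel.unparsedModelFromMap(modelConfigMap);
+     * TaskType taskType = unparsed.taskType();  // parsed from the stored task type field
+     * String service = unparsed.service();      // parsed from the stored service field
+     * // settings() and secrets() remain raw maps for the service to parse later
+     * }</pre>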
+ */ + public record UnparsedModel( + String inferenceEntityId, + TaskType taskType, + String service, + Map settings, + Map secrets + ) { + + public static UnparsedModel unparsedModelFromMap(ModelConfigMap modelConfigMap) { + if (modelConfigMap.config() == null) { + throw new ElasticsearchStatusException("Missing config map", RestStatus.BAD_REQUEST); + } + String inferenceEntityId = ServiceUtils.removeStringOrThrowIfNull(modelConfigMap.config(), ModelConfigurations.MODEL_ID); + String service = ServiceUtils.removeStringOrThrowIfNull(modelConfigMap.config(), ModelConfigurations.SERVICE); + String taskTypeStr = ServiceUtils.removeStringOrThrowIfNull(modelConfigMap.config(), TaskType.NAME); + TaskType taskType = TaskType.fromString(taskTypeStr); + + return new UnparsedModel(inferenceEntityId, taskType, service, modelConfigMap.config(), modelConfigMap.secrets()); + } + } + private static final String TASK_TYPE_FIELD = "task_type"; private static final String MODEL_ID_FIELD = "model_id"; - private static final Logger logger = LogManager.getLogger(ModelRegistryImpl.class); + private static final Logger logger = LogManager.getLogger(ModelRegistry.class); private final OriginSettingClient client; - @Inject - public ModelRegistryImpl(Client client) { + public ModelRegistry(Client client) { this.client = new OriginSettingClient(client, ClientHelper.INFERENCE_ORIGIN); } - @Override + /** + * Get a model with its secret settings + * @param inferenceEntityId Model to get + * @param listener Model listener + */ public void getModelWithSecrets(String inferenceEntityId, ActionListener listener) { ActionListener searchListener = listener.delegateFailureAndWrap((delegate, searchResponse) -> { // There should be a hit for the configurations and secrets @@ -80,7 +106,7 @@ public void getModelWithSecrets(String inferenceEntityId, ActionListener listener) { ActionListener searchListener = listener.delegateFailureAndWrap((delegate, searchResponse) -> { // There should be a hit for the configurations and secrets @@ -101,7 +132,7 @@ public void getModel(String inferenceEntityId, ActionListener lis return; } - var modelConfigs = parseHitsAsModels(searchResponse.getHits()).stream().map(ModelRegistryImpl::unparsedModelFromMap).toList(); + var modelConfigs = parseHitsAsModels(searchResponse.getHits()).stream().map(UnparsedModel::unparsedModelFromMap).toList(); assert modelConfigs.size() == 1; delegate.onResponse(modelConfigs.get(0)); }); @@ -116,7 +147,12 @@ public void getModel(String inferenceEntityId, ActionListener lis client.search(modelSearch, searchListener); } - @Override + /** + * Get all models of a particular task type. 
+     * Secret settings are not included
+     * @param taskType The task type
+     * @param listener Models listener
+     */
     public void getModelsByTaskType(TaskType taskType, ActionListener<List<UnparsedModel>> listener) {
         ActionListener<SearchResponse> searchListener = listener.delegateFailureAndWrap((delegate, searchResponse) -> {
             // Not an error if no models of this task_type
@@ -125,7 +161,7 @@
     public void getModelsByTaskType(TaskType taskType, ActionListener<List<UnparsedModel>> listener) {
         ActionListener<SearchResponse> searchListener = listener.delegateFailureAndWrap((delegate, searchResponse) -> {
             // Not an error if no models of this task_type
@@ -150,7 +190,7 @@ public void getAllModels(ActionListener<List<UnparsedModel>> listener) {
             return;
         }

-        var modelConfigs = parseHitsAsModels(searchResponse.getHits()).stream().map(ModelRegistryImpl::unparsedModelFromMap).toList();
+        var modelConfigs = parseHitsAsModels(searchResponse.getHits()).stream().map(UnparsedModel::unparsedModelFromMap).toList();
         delegate.onResponse(modelConfigs);
     });
@@ -217,7 +257,6 @@ private ModelConfigMap createModelConfigMap(SearchHits hits, String inferenceEnt
         );
     }

-    @Override
     public void storeModel(Model model, ActionListener<Boolean> listener) {
         ActionListener<BulkResponse> bulkResponseActionListener = getStoreModelListener(model, listener);
@@ -314,7 +353,6 @@ private static BulkItemResponse.Failure getFirstBulkFailure(BulkResponse bulkRes
         return null;
     }

-    @Override
     public void deleteModel(String inferenceEntityId, ActionListener<Boolean> listener) {
         DeleteByQueryRequest request = new DeleteByQueryRequest().setAbortOnVersionConflict(false);
         request.indices(InferenceIndex.INDEX_PATTERN, InferenceSecretsIndex.INDEX_PATTERN);
@@ -339,16 +377,4 @@ private static IndexRequest createIndexRequest(String docId, String indexName, T
     private QueryBuilder documentIdQuery(String inferenceEntityId) {
         return QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds(Model.documentId(inferenceEntityId)));
     }
-
-    private static UnparsedModel unparsedModelFromMap(ModelRegistryImpl.ModelConfigMap modelConfigMap) {
-        if (modelConfigMap.config() == null) {
-            throw new ElasticsearchStatusException("Missing config map", RestStatus.BAD_REQUEST);
-        }
-        String modelId = ServiceUtils.removeStringOrThrowIfNull(modelConfigMap.config(), ModelConfigurations.MODEL_ID);
-        String service = ServiceUtils.removeStringOrThrowIfNull(modelConfigMap.config(), ModelConfigurations.SERVICE);
-        String taskTypeStr = ServiceUtils.removeStringOrThrowIfNull(modelConfigMap.config(), TaskType.NAME);
-        TaskType taskType = TaskType.fromString(taskTypeStr);
-
-        return new UnparsedModel(modelId, taskType, service, modelConfigMap.config(), modelConfigMap.secrets());
-    }
 }
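Aside (illustrative only, not part of the patch): a minimal sketch of how a caller is expected to consume the now-concrete registry, resolving an UnparsedModel into a parsed Model through the owning InferenceService. The `client` and `serviceRegistry` variables and the "my-elser" id are assumptions for the example.

    // Minimal usage sketch; assumes `client`, `serviceRegistry` and a stored "my-elser" endpoint.
    ModelRegistry registry = new ModelRegistry(client);
    registry.getModelWithSecrets("my-elser", ActionListener.wrap(unparsedModel -> {
        // The registry hands back raw config/secret maps; the owning service parses them.
        InferenceService service = serviceRegistry.getService(unparsedModel.service()).orElseThrow();
        Model model = service.parsePersistedConfigWithSecrets(
            unparsedModel.inferenceEntityId(),
            unparsedModel.taskType(),
            unparsedModel.settings(),
            unparsedModel.secrets()
        );
        // ... use the parsed model ...
    }, e -> { /* not found, or the stored config failed to parse */ }));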
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java
new file mode 100644
index 0000000000000..4a1825303b5a7
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java
@@ -0,0 +1,344 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.action.filter;
+
+import org.elasticsearch.ResourceNotFoundException;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.bulk.BulkItemRequest;
+import org.elasticsearch.action.bulk.BulkItemResponse;
+import org.elasticsearch.action.bulk.BulkShardRequest;
+import org.elasticsearch.action.bulk.TransportShardBulkAction;
+import org.elasticsearch.action.index.IndexRequest;
+import org.elasticsearch.action.support.ActionFilterChain;
+import org.elasticsearch.action.support.WriteRequest;
+import org.elasticsearch.cluster.metadata.FieldInferenceMetadata;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.xcontent.XContentHelper;
+import org.elasticsearch.index.shard.ShardId;
+import org.elasticsearch.inference.ChunkedInferenceServiceResults;
+import org.elasticsearch.inference.InferenceService;
+import org.elasticsearch.inference.InferenceServiceRegistry;
+import org.elasticsearch.inference.Model;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.rest.RestStatus;
+import org.elasticsearch.tasks.Task;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.threadpool.TestThreadPool;
+import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.xcontent.json.JsonXContent;
+import org.elasticsearch.xpack.core.inference.results.ChunkedSparseEmbeddingResults;
+import org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapper;
+import org.elasticsearch.xpack.inference.model.TestModel;
+import org.elasticsearch.xpack.inference.registry.ModelRegistry;
+import org.junit.After;
+import org.junit.Before;
+import org.mockito.stubbing.Answer;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Set;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
+
+import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent;
+import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.awaitLatch;
+import static org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapperTests.randomSparseEmbeddings;
+import static org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapperTests.randomTextEmbeddings;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.instanceOf;
+import static org.mockito.Mockito.any;
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+public class ShardBulkInferenceActionFilterTests extends ESTestCase {
+    private ThreadPool threadPool;
+
+    @Before
+    public void setupThreadPool() {
+        threadPool = new TestThreadPool(getTestName());
+    }
+
+    @After
+    public void tearDownThreadPool() throws Exception {
+        terminate(threadPool);
+    }
+
+    @SuppressWarnings({ "unchecked", "rawtypes" })
+    public void testFilterNoop() throws Exception {
+        ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of());
+        CountDownLatch chainExecuted = new CountDownLatch(1);
+        ActionFilterChain actionFilterChain = (task, action, request, listener) -> {
+            try {
+                assertNull(((BulkShardRequest) request).getFieldsInferenceMetadataMap());
+            } finally {
+                chainExecuted.countDown();
+            }
+        };
+        ActionListener actionListener = mock(ActionListener.class);
+        Task task = mock(Task.class);
+        BulkShardRequest request = new BulkShardRequest(
ShardId("test", "test", 0), + WriteRequest.RefreshPolicy.NONE, + new BulkItemRequest[0] + ); + request.setFieldInferenceMetadata( + new FieldInferenceMetadata(Map.of("foo", new FieldInferenceMetadata.FieldInferenceOptions("bar", Set.of()))) + ); + filter.apply(task, TransportShardBulkAction.ACTION_NAME, request, actionListener, actionFilterChain); + awaitLatch(chainExecuted, 10, TimeUnit.SECONDS); + } + + @SuppressWarnings({ "unchecked", "rawtypes" }) + public void testInferenceNotFound() throws Exception { + StaticModel model = randomStaticModel(); + ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(model.getInferenceEntityId(), model)); + CountDownLatch chainExecuted = new CountDownLatch(1); + ActionFilterChain actionFilterChain = (task, action, request, listener) -> { + try { + BulkShardRequest bulkShardRequest = (BulkShardRequest) request; + assertNull(bulkShardRequest.getFieldsInferenceMetadataMap()); + for (BulkItemRequest item : bulkShardRequest.items()) { + assertNotNull(item.getPrimaryResponse()); + assertTrue(item.getPrimaryResponse().isFailed()); + BulkItemResponse.Failure failure = item.getPrimaryResponse().getFailure(); + assertThat(failure.getStatus(), equalTo(RestStatus.NOT_FOUND)); + } + } finally { + chainExecuted.countDown(); + } + }; + ActionListener actionListener = mock(ActionListener.class); + Task task = mock(Task.class); + + FieldInferenceMetadata inferenceFields = new FieldInferenceMetadata( + Map.of( + "field1", + new FieldInferenceMetadata.FieldInferenceOptions(model.getInferenceEntityId(), Set.of()), + "field2", + new FieldInferenceMetadata.FieldInferenceOptions("inference_0", Set.of()), + "field3", + new FieldInferenceMetadata.FieldInferenceOptions("inference_0", Set.of()) + ) + ); + BulkItemRequest[] items = new BulkItemRequest[10]; + for (int i = 0; i < items.length; i++) { + items[i] = randomBulkItemRequest(i, Map.of(), inferenceFields)[0]; + } + BulkShardRequest request = new BulkShardRequest(new ShardId("test", "test", 0), WriteRequest.RefreshPolicy.NONE, items); + request.setFieldInferenceMetadata(inferenceFields); + filter.apply(task, TransportShardBulkAction.ACTION_NAME, request, actionListener, actionFilterChain); + awaitLatch(chainExecuted, 10, TimeUnit.SECONDS); + } + + @SuppressWarnings({ "unchecked", "rawtypes" }) + public void testManyRandomDocs() throws Exception { + Map inferenceModelMap = new HashMap<>(); + int numModels = randomIntBetween(1, 5); + for (int i = 0; i < numModels; i++) { + StaticModel model = randomStaticModel(); + inferenceModelMap.put(model.getInferenceEntityId(), model); + } + + int numInferenceFields = randomIntBetween(1, 5); + Map inferenceFieldsMap = new HashMap<>(); + for (int i = 0; i < numInferenceFields; i++) { + String field = randomAlphaOfLengthBetween(5, 10); + String inferenceId = randomFrom(inferenceModelMap.keySet()); + inferenceFieldsMap.put(field, new FieldInferenceMetadata.FieldInferenceOptions(inferenceId, Set.of())); + } + FieldInferenceMetadata fieldInferenceMetadata = new FieldInferenceMetadata(inferenceFieldsMap); + + int numRequests = randomIntBetween(100, 1000); + BulkItemRequest[] originalRequests = new BulkItemRequest[numRequests]; + BulkItemRequest[] modifiedRequests = new BulkItemRequest[numRequests]; + for (int id = 0; id < numRequests; id++) { + BulkItemRequest[] res = randomBulkItemRequest(id, inferenceModelMap, fieldInferenceMetadata); + originalRequests[id] = res[0]; + modifiedRequests[id] = res[1]; + } + + ShardBulkInferenceActionFilter filter = createFilter(threadPool, 
+        ShardBulkInferenceActionFilter filter = createFilter(threadPool, inferenceModelMap);
+        CountDownLatch chainExecuted = new CountDownLatch(1);
+        ActionFilterChain actionFilterChain = (task, action, request, listener) -> {
+            try {
+                assertThat(request, instanceOf(BulkShardRequest.class));
+                BulkShardRequest bulkShardRequest = (BulkShardRequest) request;
+                assertNull(bulkShardRequest.getFieldsInferenceMetadataMap());
+                BulkItemRequest[] items = bulkShardRequest.items();
+                assertThat(items.length, equalTo(originalRequests.length));
+                for (int id = 0; id < items.length; id++) {
+                    IndexRequest actualRequest = ShardBulkInferenceActionFilter.getIndexRequestOrNull(items[id].request());
+                    IndexRequest expectedRequest = ShardBulkInferenceActionFilter.getIndexRequestOrNull(modifiedRequests[id].request());
+                    try {
+                        assertToXContentEquivalent(expectedRequest.source(), actualRequest.source(), actualRequest.getContentType());
+                    } catch (Exception exc) {
+                        throw new IllegalStateException(exc);
+                    }
+                }
+            } finally {
+                chainExecuted.countDown();
+            }
+        };
+        ActionListener actionListener = mock(ActionListener.class);
+        Task task = mock(Task.class);
+        BulkShardRequest original = new BulkShardRequest(new ShardId("test", "test", 0), WriteRequest.RefreshPolicy.NONE, originalRequests);
+        original.setFieldInferenceMetadata(fieldInferenceMetadata);
+        filter.apply(task, TransportShardBulkAction.ACTION_NAME, original, actionListener, actionFilterChain);
+        awaitLatch(chainExecuted, 10, TimeUnit.SECONDS);
+    }
+
+    @SuppressWarnings("unchecked")
+    private static ShardBulkInferenceActionFilter createFilter(ThreadPool threadPool, Map<String, StaticModel> modelMap) {
+        ModelRegistry modelRegistry = mock(ModelRegistry.class);
+        Answer unparsedModelAnswer = invocationOnMock -> {
+            String id = (String) invocationOnMock.getArguments()[0];
+            ActionListener<ModelRegistry.UnparsedModel> listener = (ActionListener<ModelRegistry.UnparsedModel>) invocationOnMock
+                .getArguments()[1];
+            var model = modelMap.get(id);
+            if (model != null) {
+                listener.onResponse(
+                    new ModelRegistry.UnparsedModel(
+                        model.getInferenceEntityId(),
+                        model.getTaskType(),
+                        model.getServiceSettings().model(),
+                        XContentHelper.convertToMap(JsonXContent.jsonXContent, Strings.toString(model.getTaskSettings()), false),
+                        XContentHelper.convertToMap(JsonXContent.jsonXContent, Strings.toString(model.getSecretSettings()), false)
+                    )
+                );
+            } else {
+                listener.onFailure(new ResourceNotFoundException("model id [{}] not found", id));
+            }
+            return null;
+        };
+        doAnswer(unparsedModelAnswer).when(modelRegistry).getModelWithSecrets(any(), any());
+
+        InferenceService inferenceService = mock(InferenceService.class);
+        Answer chunkedInferAnswer = invocationOnMock -> {
+            StaticModel model = (StaticModel) invocationOnMock.getArguments()[0];
+            List<String> inputs = (List<String>) invocationOnMock.getArguments()[1];
+            ActionListener<List<ChunkedInferenceServiceResults>> listener = (ActionListener<
+                List<ChunkedInferenceServiceResults>>) invocationOnMock.getArguments()[5];
+            Runnable runnable = () -> {
+                List<ChunkedInferenceServiceResults> results = new ArrayList<>();
+                for (String input : inputs) {
+                    results.add(model.getResults(input));
+                }
+                listener.onResponse(results);
+            };
+            if (randomBoolean()) {
+                try {
+                    threadPool.generic().execute(runnable);
+                } catch (Exception exc) {
+                    listener.onFailure(exc);
+                }
+            } else {
+                runnable.run();
+            }
+            return null;
+        };
+        doAnswer(chunkedInferAnswer).when(inferenceService).chunkedInfer(any(), any(), any(), any(), any(), any());
+
+        Answer modelAnswer = invocationOnMock -> {
+            String inferenceId = (String) invocationOnMock.getArguments()[0];
+            return modelMap.get(inferenceId);
+        };
+        doAnswer(modelAnswer).when(inferenceService).parsePersistedConfigWithSecrets(any(), any(), any(), any());
+
+        InferenceServiceRegistry inferenceServiceRegistry = mock(InferenceServiceRegistry.class);
+        when(inferenceServiceRegistry.getService(any())).thenReturn(Optional.of(inferenceService));
+        ShardBulkInferenceActionFilter filter = new ShardBulkInferenceActionFilter(inferenceServiceRegistry, modelRegistry);
+        return filter;
+    }
+
+    private static BulkItemRequest[] randomBulkItemRequest(
+        int id,
+        Map<String, StaticModel> modelMap,
+        FieldInferenceMetadata fieldInferenceMetadata
+    ) {
+        Map<String, Object> docMap = new LinkedHashMap<>();
+        Map<String, Object> inferenceResultsMap = new LinkedHashMap<>();
+        for (var entry : fieldInferenceMetadata.getFieldInferenceOptions().entrySet()) {
+            String field = entry.getKey();
+            var model = modelMap.get(entry.getValue().inferenceId());
+            String text = randomAlphaOfLengthBetween(10, 100);
+            docMap.put(field, text);
+            if (model == null) {
+                // ignore results, the doc should fail with a resource not found exception
+                continue;
+            }
+            int numChunks = randomIntBetween(1, 5);
+            List<String> chunks = new ArrayList<>();
+            for (int i = 0; i < numChunks; i++) {
+                chunks.add(randomAlphaOfLengthBetween(5, 10));
+            }
+            TaskType taskType = model.getTaskType();
+            final ChunkedInferenceServiceResults results;
+            switch (taskType) {
+                case TEXT_EMBEDDING:
+                    results = randomTextEmbeddings(chunks);
+                    break;
+
+                case SPARSE_EMBEDDING:
+                    results = randomSparseEmbeddings(chunks);
+                    break;
+
+                default:
+                    throw new AssertionError("Unknown task type " + taskType.name());
+            }
+            model.putResult(text, results);
+            InferenceResultFieldMapper.applyFieldInference(inferenceResultsMap, field, model, results);
+        }
+        Map<String, Object> expectedDocMap = new LinkedHashMap<>(docMap);
+        expectedDocMap.put(InferenceResultFieldMapper.NAME, inferenceResultsMap);
+        return new BulkItemRequest[] {
+            new BulkItemRequest(id, new IndexRequest("index").source(docMap)),
+            new BulkItemRequest(id, new IndexRequest("index").source(expectedDocMap)) };
+    }
+
+    private static StaticModel randomStaticModel() {
+        String serviceName = randomAlphaOfLengthBetween(5, 10);
+        String inferenceId = randomAlphaOfLengthBetween(5, 10);
+        return new StaticModel(
+            inferenceId,
+            randomBoolean() ? TaskType.TEXT_EMBEDDING : TaskType.SPARSE_EMBEDDING,
+            serviceName,
+            new TestModel.TestServiceSettings("my-model"),
+            new TestModel.TestTaskSettings(randomIntBetween(1, 100)),
+            new TestModel.TestSecretSettings(randomAlphaOfLength(10))
+        );
+    }
+
+    private static class StaticModel extends TestModel {
+        private final Map<String, ChunkedInferenceServiceResults> resultMap;
+
+        StaticModel(
+            String inferenceEntityId,
+            TaskType taskType,
+            String service,
+            TestServiceSettings serviceSettings,
+            TestTaskSettings taskSettings,
+            TestSecretSettings secretSettings
+        ) {
+            super(inferenceEntityId, taskType, service, serviceSettings, taskSettings, secretSettings);
+            this.resultMap = new HashMap<>();
+        }
+
+        ChunkedInferenceServiceResults getResults(String text) {
+            return resultMap.getOrDefault(text, new ChunkedSparseEmbeddingResults(List.of()));
+        }
+
+        void putResult(String text, ChunkedInferenceServiceResults results) {
+            resultMap.put(text, results);
+        }
+    }
+}
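Aside (illustrative only, not part of the patch): the tests above drive the filter through Elasticsearch's generic ActionFilter contract, which has roughly the shape below. The class name and order value are assumptions, and the inference step is elided.

    import org.elasticsearch.action.ActionListener;
    import org.elasticsearch.action.ActionRequest;
    import org.elasticsearch.action.ActionResponse;
    import org.elasticsearch.action.bulk.TransportShardBulkAction;
    import org.elasticsearch.action.support.ActionFilter;
    import org.elasticsearch.action.support.ActionFilterChain;
    import org.elasticsearch.tasks.Task;

    // Sketch of a filter that rewrites shard-level bulk requests before the
    // rest of the chain (and finally the transport action) runs.
    class ExampleShardBulkFilter implements ActionFilter {
        @Override
        public int order() {
            return Integer.MAX_VALUE; // run late in the chain; the exact value is an assumption
        }

        @Override
        public <Request extends ActionRequest, Response extends ActionResponse> void apply(
            Task task,
            String action,
            Request request,
            ActionListener<Response> listener,
            ActionFilterChain<Request, Response> chain
        ) {
            if (TransportShardBulkAction.ACTION_NAME.equals(action) == false) {
                chain.proceed(task, action, request, listener);
                return;
            }
            // ... compute embeddings for the request items here, then continue ...
            chain.proceed(task, action, request, listener);
        }
    }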
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapperTests.java
similarity index 79%
rename from x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapperTests.java
rename to x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapperTests.java
index 319f6ef73fa56..b5d75b528c6ab 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapperTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapperTests.java
@@ -31,49 +31,46 @@
 import org.elasticsearch.index.mapper.NestedObjectMapper;
 import org.elasticsearch.index.mapper.ParsedDocument;
 import org.elasticsearch.index.search.ESToParentBlockJoinQuery;
-import org.elasticsearch.inference.SemanticTextModelSettings;
+import org.elasticsearch.inference.ChunkedInferenceServiceResults;
+import org.elasticsearch.inference.Model;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.plugins.Plugin;
 import org.elasticsearch.search.LeafNestedDocuments;
 import org.elasticsearch.search.NestedDocuments;
 import org.elasticsearch.search.SearchHit;
 import org.elasticsearch.xcontent.XContentBuilder;
-import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
+import org.elasticsearch.xpack.core.inference.results.ChunkedSparseEmbeddingResults;
+import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingResults;
+import org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextExpansionResults;
+import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;
 import org.elasticsearch.xpack.inference.InferencePlugin;
+import org.elasticsearch.xpack.inference.model.TestModel;

 import java.io.IOException;
 import java.util.ArrayList;
-import java.util.Arrays;
 import java.util.Collection;
 import java.util.HashMap;
 import java.util.HashSet;
-import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import java.util.function.Consumer;

-import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.INFERENCE_CHUNKS_RESULTS;
-import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.INFERENCE_CHUNKS_TEXT;
-import static org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider.INFERENCE_RESULTS;
+import static org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapper.INFERENCE_CHUNKS_RESULTS;
+import static org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapper.INFERENCE_CHUNKS_TEXT;
+import static org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapper.RESULTS;
 import static org.hamcrest.Matchers.containsString;

-public class SemanticTextInferenceResultFieldMapperTests extends MetadataMapperTestCase {
-    private record SemanticTextInferenceResults(String fieldName, SparseEmbeddingResults sparseEmbeddingResults, List<String> text) {
-        private SemanticTextInferenceResults {
-            if (sparseEmbeddingResults.embeddings().size() != text.size()) {
-                throw new IllegalArgumentException("Sparse embeddings and text must be the same size");
-            }
-        }
-    }
+public class InferenceResultFieldMapperTests extends MetadataMapperTestCase {
+    private record SemanticTextInferenceResults(String fieldName, ChunkedInferenceServiceResults results, List<String> text) {}

-    private record VisitedChildDocInfo(String path, int sparseVectorDims) {}
+    private record VisitedChildDocInfo(String path, int numChunks) {}

     private record SparseVectorSubfieldOptions(boolean include, boolean includeEmbedding, boolean includeIsTruncated) {}

     @Override
     protected String fieldName() {
-        return SemanticTextInferenceResultFieldMapper.NAME;
+        return InferenceResultFieldMapper.NAME;
     }

     @Override
@@ -109,8 +106,8 @@ public void testSuccessfulParse() throws IOException {
             b -> addSemanticTextInferenceResults(
                 b,
                 List.of(
-                    generateSemanticTextinferenceResults(fieldName1, List.of("a b", "c")),
-                    generateSemanticTextinferenceResults(fieldName2, List.of("d e f"))
+                    randomSemanticTextInferenceResults(fieldName1, List.of("a b", "c")),
+                    randomSemanticTextInferenceResults(fieldName2, List.of("d e f"))
                 )
             )
         )
@@ -209,10 +206,10 @@ public void testMissingSubfields() throws IOException {
                 source(
                     b -> addSemanticTextInferenceResults(
                         b,
-                        List.of(generateSemanticTextinferenceResults(fieldName, List.of("a b"))),
+                        List.of(randomSemanticTextInferenceResults(fieldName, List.of("a b"))),
                         new SparseVectorSubfieldOptions(false, true, true),
                         true,
-                        null
+                        Map.of()
                     )
                 )
             )
@@ -227,10 +224,10 @@ public void testMissingSubfields() throws IOException {
                 source(
                     b -> addSemanticTextInferenceResults(
                         b,
-                        List.of(generateSemanticTextinferenceResults(fieldName, List.of("a b"))),
+                        List.of(randomSemanticTextInferenceResults(fieldName, List.of("a b"))),
                         new SparseVectorSubfieldOptions(true, true, true),
                         false,
-                        null
+                        Map.of()
                     )
                 )
             )
@@ -245,10 +242,10 @@ public void testMissingSubfields() throws IOException {
                 source(
                     b -> addSemanticTextInferenceResults(
                         b,
-                        List.of(generateSemanticTextinferenceResults(fieldName, List.of("a b"))),
+                        List.of(randomSemanticTextInferenceResults(fieldName, List.of("a b"))),
                         new SparseVectorSubfieldOptions(false, true, true),
                         false,
-                        null
+                        Map.of()
                     )
                 )
             )
@@ -263,7 +260,7 @@ public void testExtraSubfields() throws IOException {
         final String fieldName = randomAlphaOfLengthBetween(5, 15);
         final List<SemanticTextInferenceResults> semanticTextInferenceResultsList = List.of(
-            generateSemanticTextinferenceResults(fieldName, List.of("a b"))
+            randomSemanticTextInferenceResults(fieldName, List.of("a b"))
         );

         DocumentMapper documentMapper = createDocumentMapper(mapping(b -> addSemanticTextMapping(b, fieldName, randomAlphaOfLength(8))));
@@ -361,7 +358,7 @@ public void testMissingSemanticTextMapping() throws IOException {
             DocumentParsingException.class,
             DocumentParsingException.class,
             () -> documentMapper.parse(
-                source(b -> addSemanticTextInferenceResults(b, List.of(generateSemanticTextinferenceResults(fieldName, List.of("a b")))))
+                source(b -> addSemanticTextInferenceResults(b, List.of(randomSemanticTextInferenceResults(fieldName, List.of("a b")))))
             )
         );
         assertThat(
@@ -379,18 +376,32 @@ private static void addSemanticTextMapping(XContentBuilder mappingBuilder, Strin
         mappingBuilder.endObject();
     }

-    private static SemanticTextInferenceResults generateSemanticTextinferenceResults(String semanticTextFieldName, List<String> chunks) {
-        List<SparseEmbeddingResults.Embedding> embeddings = new ArrayList<>(chunks.size());
-        for (String chunk : chunks) {
-            String[] tokens = chunk.split("\\s+");
-            List<SparseEmbeddingResults.WeightedToken> weightedTokens = Arrays.stream(tokens)
-                .map(t -> new SparseEmbeddingResults.WeightedToken(t, randomFloat()))
-                .toList();
+    public static ChunkedTextEmbeddingResults randomTextEmbeddings(List<String> inputs) {
+        List<org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextEmbeddingResults.EmbeddingChunk> chunks = new ArrayList<>();
+        for (String input : inputs) {
+            double[] values = new double[5];
+            for (int j = 0; j < values.length; j++) {
+                values[j] = randomDouble();
+            }
+            chunks.add(new org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextEmbeddingResults.EmbeddingChunk(input, values));
+        }
+        return new ChunkedTextEmbeddingResults(chunks);
+    }

-            embeddings.add(new SparseEmbeddingResults.Embedding(weightedTokens, false));
+    public static ChunkedSparseEmbeddingResults randomSparseEmbeddings(List<String> inputs) {
+        List<ChunkedTextExpansionResults.ChunkedResult> chunks = new ArrayList<>();
+        for (String input : inputs) {
+            var tokens = new ArrayList<TextExpansionResults.WeightedToken>();
+            for (var token : input.split("\\s+")) {
+                tokens.add(new TextExpansionResults.WeightedToken(token, randomFloat()));
+            }
+            chunks.add(new ChunkedTextExpansionResults.ChunkedResult(input, tokens));
         }
+        return new ChunkedSparseEmbeddingResults(chunks);
+    }

-        return new SemanticTextInferenceResults(semanticTextFieldName, new SparseEmbeddingResults(embeddings), chunks);
+    private static SemanticTextInferenceResults randomSemanticTextInferenceResults(String semanticTextFieldName, List<String> chunks) {
+        return new SemanticTextInferenceResults(semanticTextFieldName, randomSparseEmbeddings(chunks), chunks);
     }

     private static void addSemanticTextInferenceResults(
@@ -402,10 +413,11 @@ private static void addSemanticTextInferenceResults(
             semanticTextInferenceResults,
             new SparseVectorSubfieldOptions(true, true, true),
             true,
-            null
+            Map.of()
         );
     }

+    @SuppressWarnings("unchecked")
     private static void addSemanticTextInferenceResults(
         XContentBuilder sourceBuilder,
         List<SemanticTextInferenceResults> semanticTextInferenceResults,
@@ -413,48 +425,39 @@ private static void addSemanticTextInferenceResults(
         boolean includeTextSubfield,
         Map<String, Object> extraSubfields
     ) throws IOException {
-
-        Map<String, Map<String, Object>> inferenceResultsMap = new HashMap<>();
+        Map<String, Object> inferenceResultsMap = new HashMap<>();
         for (SemanticTextInferenceResults semanticTextInferenceResult : semanticTextInferenceResults) {
-            Map<String, Object> fieldMap = new HashMap<>();
-            fieldMap.put(SemanticTextModelSettings.NAME, modelSettingsMap());
-            List<Map<String, Object>> parsedInferenceResults = new ArrayList<>(semanticTextInferenceResult.text().size());
-
-            Iterator<SparseEmbeddingResults.Embedding> embeddingsIterator = semanticTextInferenceResult.sparseEmbeddingResults()
-                .embeddings()
-                .iterator();
-            Iterator<String> textIterator = semanticTextInferenceResult.text().iterator();
-            while (embeddingsIterator.hasNext() && textIterator.hasNext()) {
-                SparseEmbeddingResults.Embedding embedding = embeddingsIterator.next();
-                String text = textIterator.next();
-
-                Map<String, Object> subfieldMap = new HashMap<>();
-                if (sparseVectorSubfieldOptions.include()) {
-                    subfieldMap.put(INFERENCE_CHUNKS_RESULTS, embedding.asMap().get(SparseEmbeddingResults.Embedding.EMBEDDING));
-                }
-                if (includeTextSubfield) {
-                    subfieldMap.put(INFERENCE_CHUNKS_TEXT, text);
+            InferenceResultFieldMapper.applyFieldInference(
+                inferenceResultsMap,
+                semanticTextInferenceResult.fieldName,
+                randomModel(),
+                semanticTextInferenceResult.results
+            );
+            Map<String, Object> optionsMap = (Map<String, Object>) inferenceResultsMap.get(semanticTextInferenceResult.fieldName);
+            List<Map<String, Object>> fieldResultList = (List<Map<String, Object>>) optionsMap.get(RESULTS);
+            for (var entry : fieldResultList) {
+                if (includeTextSubfield == false) {
+                    entry.remove(INFERENCE_CHUNKS_TEXT);
                 }
-                if (extraSubfields != null) {
-                    subfieldMap.putAll(extraSubfields);
+                if (sparseVectorSubfieldOptions.include == false) {
+                    entry.remove(INFERENCE_CHUNKS_RESULTS);
                 }
-
-                parsedInferenceResults.add(subfieldMap);
+                entry.putAll(extraSubfields);
             }
-
-            fieldMap.put(INFERENCE_RESULTS, parsedInferenceResults);
-            inferenceResultsMap.put(semanticTextInferenceResult.fieldName(), fieldMap);
         }
-
-        sourceBuilder.field(SemanticTextInferenceResultFieldMapper.NAME, inferenceResultsMap);
+        sourceBuilder.field(InferenceResultFieldMapper.NAME, inferenceResultsMap);
     }

-    private static Map<String, Object> modelSettingsMap() {
-        return Map.of(
-            SemanticTextModelSettings.TASK_TYPE_FIELD.getPreferredName(),
-            TaskType.SPARSE_EMBEDDING.toString(),
-            SemanticTextModelSettings.INFERENCE_ID_FIELD.getPreferredName(),
-            randomAlphaOfLength(8)
+    private static Model randomModel() {
+        String serviceName = randomAlphaOfLengthBetween(5, 10);
+        String inferenceId = randomAlphaOfLengthBetween(5, 10);
+        return new TestModel(
+            inferenceId,
+            TaskType.SPARSE_EMBEDDING,
+            serviceName,
+            new TestModel.TestServiceSettings("my-model"),
+            new TestModel.TestTaskSettings(randomIntBetween(1, 100)),
+            new TestModel.TestSecretSettings(randomAlphaOfLength(10))
         );
     }
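Aside (illustrative only, not part of the patch): applyFieldInference, exercised above, fills in the per-field structure stored under the `_inference` root field. Assembled by hand with the values used in the YAML fixtures below, the equivalent source looks like this sketch:

    import java.util.List;
    import java.util.Map;

    // The _inference source shape for one sparse-embedding field, mirroring what
    // applyFieldInference emits and what the mapper and YAML tests assert on.
    Map<String, Object> source = Map.of(
        "_inference",
        Map.of(
            "sparse_field",
            Map.of(
                "model_settings",
                Map.of("inference_id", "sparse-inference-id", "task_type", "sparse_embedding"),
                "results",
                List.of(Map.of("text", "inference test", "inference", Map.of("feature_1", 0.1)))
            )
        )
    );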
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryImplTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java
similarity index 92%
rename from x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryImplTests.java
rename to x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java
index fd6a203450c12..2417148c84ac2 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryImplTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java
@@ -45,7 +45,7 @@ import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;

-public class ModelRegistryImplTests extends ESTestCase {
+public class ModelRegistryTests extends ESTestCase {

     private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);

@@ -65,9 +65,9 @@ public void testGetUnparsedModelMap_ThrowsResourceNotFound_WhenNoHitsReturned()
         var client = mockClient();
         mockClientExecuteSearch(client, mockSearchResponse(SearchHits.EMPTY));

-        var registry = new ModelRegistryImpl(client);
+        var registry = new ModelRegistry(client);

-        var listener = new PlainActionFuture<ModelRegistry.UnparsedModel>();
+        var listener = new PlainActionFuture<UnparsedModel>();
         registry.getModelWithSecrets("1", listener);

         ResourceNotFoundException exception = expectThrows(ResourceNotFoundException.class, () -> listener.actionGet(TIMEOUT));
@@ -79,9 +79,9 @@ public void testGetUnparsedModelMap_ThrowsIllegalArgumentException_WhenInvalidIn
         var unknownIndexHit = SearchHit.createFromMap(Map.of("_index", "unknown_index"));
         mockClientExecuteSearch(client, mockSearchResponse(new SearchHit[] { unknownIndexHit }));

-        var registry = new ModelRegistryImpl(client);
+        var registry = new ModelRegistry(client);

-        var listener = new PlainActionFuture<ModelRegistry.UnparsedModel>();
+        var listener = new PlainActionFuture<UnparsedModel>();
         registry.getModelWithSecrets("1", listener);

         IllegalArgumentException exception = expectThrows(IllegalArgumentException.class, () -> listener.actionGet(TIMEOUT));
@@ -96,9 +96,9 @@ public void testGetUnparsedModelMap_ThrowsIllegalStateException_WhenUnableToFind
         var inferenceSecretsHit = SearchHit.createFromMap(Map.of("_index", ".secrets-inference"));
         mockClientExecuteSearch(client, mockSearchResponse(new SearchHit[] { inferenceSecretsHit }));

-        var registry = new ModelRegistryImpl(client);
+        var registry = new ModelRegistry(client);

-        var listener = new PlainActionFuture<ModelRegistry.UnparsedModel>();
+        var listener = new PlainActionFuture<UnparsedModel>();
         registry.getModelWithSecrets("1", listener);

         IllegalStateException exception = expectThrows(IllegalStateException.class, () -> listener.actionGet(TIMEOUT));
@@ -113,9 +113,9 @@ public void testGetUnparsedModelMap_ThrowsIllegalStateException_WhenUnableToFind
         var inferenceHit = SearchHit.createFromMap(Map.of("_index", ".inference"));
         mockClientExecuteSearch(client, mockSearchResponse(new SearchHit[] { inferenceHit }));

-        var registry = new ModelRegistryImpl(client);
+        var registry = new ModelRegistry(client);

-        var listener = new PlainActionFuture<ModelRegistry.UnparsedModel>();
+        var listener = new PlainActionFuture<UnparsedModel>();
         registry.getModelWithSecrets("1", listener);

         IllegalStateException exception = expectThrows(IllegalStateException.class, () -> listener.actionGet(TIMEOUT));
@@ -147,9 +147,9 @@ public void testGetModelWithSecrets() {

         mockClientExecuteSearch(client, mockSearchResponse(new SearchHit[] { inferenceHit, inferenceSecretsHit }));

-        var registry = new ModelRegistryImpl(client);
+        var registry = new ModelRegistry(client);

-        var listener = new PlainActionFuture<ModelRegistry.UnparsedModel>();
+        var listener = new PlainActionFuture<UnparsedModel>();
         registry.getModelWithSecrets("1", listener);

         var modelConfig = listener.actionGet(TIMEOUT);
@@ -176,9 +176,9 @@ public void testGetModelNoSecrets() {

         mockClientExecuteSearch(client, mockSearchResponse(new SearchHit[] { inferenceHit }));

-        var registry = new ModelRegistryImpl(client);
+        var registry = new ModelRegistry(client);

-        var listener = new PlainActionFuture<ModelRegistry.UnparsedModel>();
+        var listener = new PlainActionFuture<UnparsedModel>();
         registry.getModel("1", listener);

@@ -201,7 +201,7 @@ public void testStoreModel_ReturnsTrue_WhenNoFailuresOccur() {
         mockClientExecuteBulk(client, bulkResponse);

         var model = TestModel.createRandomInstance();
-        var registry = new ModelRegistryImpl(client);
+        var registry = new ModelRegistry(client);
         var listener = new PlainActionFuture<Boolean>();

         registry.storeModel(model, listener);
@@ -218,7 +218,7 @@ public void testStoreModel_ThrowsException_WhenBulkResponseIsEmpty() {
         mockClientExecuteBulk(client, bulkResponse);

         var model = TestModel.createRandomInstance();
-        var registry = new ModelRegistryImpl(client);
+        var registry = new ModelRegistry(client);
         var listener = new PlainActionFuture<Boolean>();

         registry.storeModel(model, listener);
@@ -249,7 +249,7 @@ public void testStoreModel_ThrowsResourceAlreadyExistsException_WhenFailureIsAVe
         mockClientExecuteBulk(client, bulkResponse);

         var model = TestModel.createRandomInstance();
-        var registry = new ModelRegistryImpl(client);
+        var registry = new ModelRegistry(client);
         var listener = new PlainActionFuture<Boolean>();

         registry.storeModel(model, listener);
@@ -275,7 +275,7 @@ public void testStoreModel_ThrowsException_WhenFailureIsNotAVersionConflict() {
         mockClientExecuteBulk(client, bulkResponse);

         var model = TestModel.createRandomInstance();
-        var registry = new ModelRegistryImpl(client);
+        var registry = new ModelRegistry(client);
         var listener = new PlainActionFuture<Boolean>();

         registry.storeModel(model, listener);
diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml
index ead7f904ad57b..6008ebbcbedf8 100644
--- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml
+++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml
@@ -83,11 +83,11 @@ setup:
   - match: { _source.another_inference_field: "another inference test" }
   - match: { _source.non_inference_field: "non inference test" }

-  - match: { _source._semantic_text_inference.inference_field.inference_results.0.text: "inference test" }
-  - match: { _source._semantic_text_inference.another_inference_field.inference_results.0.text: "another inference test" }
+  - match: { _source._inference.inference_field.results.0.text: "inference test" }
+  - match: { _source._inference.another_inference_field.results.0.text: "another inference test" }

-  - exists: _source._semantic_text_inference.inference_field.inference_results.0.inference
-  - exists: _source._semantic_text_inference.another_inference_field.inference_results.0.inference
+  - exists: _source._inference.inference_field.results.0.inference
+  - exists: _source._inference.another_inference_field.results.0.inference

 ---
 "text expansion documents do not create new mappings":
@@ -120,11 +120,11 @@ setup:
   - match: { _source.another_inference_field: "another inference test" }
   - match: { _source.non_inference_field: "non inference test" }

-  - match: { _source._semantic_text_inference.inference_field.inference_results.0.text: "inference test" }
-  - match: { _source._semantic_text_inference.another_inference_field.inference_results.0.text: "another inference test" }
+  - match: { _source._inference.inference_field.results.0.text: "inference test" }
+  - match: { _source._inference.another_inference_field.results.0.text: "another inference test" }

-  - exists: _source._semantic_text_inference.inference_field.inference_results.0.inference
-  - exists: _source._semantic_text_inference.another_inference_field.inference_results.0.inference
+  - exists: _source._inference.inference_field.results.0.inference
+  - exists: _source._inference.another_inference_field.results.0.inference

 ---
@@ -154,8 +154,8 @@ setup:
       index: test-sparse-index
       id: doc_1

-  - set: { _source._semantic_text_inference.inference_field.inference_results.0.inference: inference_field_embedding }
-  - set: { _source._semantic_text_inference.another_inference_field.inference_results.0.inference: another_inference_field_embedding }
+  - set: { _source._inference.inference_field.results.0.inference: inference_field_embedding }
+  - set: { _source._inference.another_inference_field.results.0.inference: another_inference_field_embedding }

   - do:
       update:
@@ -174,11 +174,11 @@ setup:
   - match: { _source.another_inference_field: "another inference test" }
"another non inference test" } - - match: { _source._semantic_text_inference.inference_field.inference_results.0.text: "inference test" } - - match: { _source._semantic_text_inference.another_inference_field.inference_results.0.text: "another inference test" } + - match: { _source._inference.inference_field.results.0.text: "inference test" } + - match: { _source._inference.another_inference_field.results.0.text: "another inference test" } - - match: { _source._semantic_text_inference.inference_field.inference_results.0.inference: $inference_field_embedding } - - match: { _source._semantic_text_inference.another_inference_field.inference_results.0.inference: $another_inference_field_embedding } + - match: { _source._inference.inference_field.results.0.inference: $inference_field_embedding } + - match: { _source._inference.another_inference_field.results.0.inference: $another_inference_field_embedding } --- "Updating semantic_text fields recalculates embeddings": @@ -214,8 +214,8 @@ setup: - match: { _source.another_inference_field: "another updated inference test" } - match: { _source.non_inference_field: "non inference test" } - - match: { _source._semantic_text_inference.inference_field.inference_results.0.text: "updated inference test" } - - match: { _source._semantic_text_inference.another_inference_field.inference_results.0.text: "another updated inference test" } + - match: { _source._inference.inference_field.results.0.text: "updated inference test" } + - match: { _source._inference.another_inference_field.results.0.text: "another updated inference test" } --- "Reindex works for semantic_text fields": @@ -233,8 +233,8 @@ setup: index: test-sparse-index id: doc_1 - - set: { _source._semantic_text_inference.inference_field.inference_results.0.inference: inference_field_embedding } - - set: { _source._semantic_text_inference.another_inference_field.inference_results.0.inference: another_inference_field_embedding } + - set: { _source._inference.inference_field.results.0.inference: inference_field_embedding } + - set: { _source._inference.another_inference_field.results.0.inference: another_inference_field_embedding } - do: indices.refresh: { } @@ -271,11 +271,11 @@ setup: - match: { _source.another_inference_field: "another inference test" } - match: { _source.non_inference_field: "non inference test" } - - match: { _source._semantic_text_inference.inference_field.inference_results.0.text: "inference test" } - - match: { _source._semantic_text_inference.another_inference_field.inference_results.0.text: "another inference test" } + - match: { _source._inference.inference_field.results.0.text: "inference test" } + - match: { _source._inference.another_inference_field.results.0.text: "another inference test" } - - match: { _source._semantic_text_inference.inference_field.inference_results.0.inference: $inference_field_embedding } - - match: { _source._semantic_text_inference.another_inference_field.inference_results.0.inference: $another_inference_field_embedding } + - match: { _source._inference.inference_field.results.0.inference: $inference_field_embedding } + - match: { _source._inference.another_inference_field.results.0.inference: $another_inference_field_embedding } --- "Fails for non-existent model": @@ -292,7 +292,7 @@ setup: type: text - do: - catch: bad_request + catch: missing index: index: incorrect-test-sparse-index id: doc_1 @@ -300,7 +300,7 @@ setup: inference_field: "inference test" non_inference_field: "non inference test" - - match: { error.reason: "No inference provider 
-  - match: { error.reason: "No inference provider found for model ID non-existing-inference-id" }
+  - match: { error.reason: "Inference id [non-existing-inference-id] not found for field [inference_field]" }

 # Succeeds when semantic_text field is not used
   - do:
diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml
index da61e6e403ed8..2c69f49218091 100644
--- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml
+++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml
@@ -56,12 +56,12 @@ setup:
         id: doc_1
         body:
           non_inference_field: "you know, for testing"
-          _semantic_text_inference:
+          _inference:
             sparse_field:
               model_settings:
                 inference_id: sparse-inference-id
                 task_type: sparse_embedding
-              inference_results:
+              results:
                 - text: "inference test"
                   inference:
                     feature_1: 0.1
@@ -83,14 +83,14 @@ setup:
         id: doc_1
         body:
           non_inference_field: "you know, for testing"
-          _semantic_text_inference:
+          _inference:
             dense_field:
               model_settings:
                 inference_id: sparse-inference-id
                 task_type: text_embedding
                 dimensions: 5
                 similarity: cosine
-              inference_results:
+              results:
                 - text: "inference test"
                   inference: [0.1, 0.2, 0.3, 0.4, 0.5]
                 - text: "another inference test"
@@ -105,11 +105,11 @@ setup:
         id: doc_1
         body:
           non_inference_field: "you know, for testing"
-          _semantic_text_inference:
+          _inference:
             sparse_field:
               model_settings:
                 task_type: sparse_embedding
-              inference_results:
+              results:
                 - text: "inference test"
                   inference:
                     feature_1: 0.1
@@ -123,11 +123,11 @@ setup:
         id: doc_1
         body:
           non_inference_field: "you know, for testing"
-          _semantic_text_inference:
+          _inference:
             sparse_field:
               model_settings:
                 inference_id: sparse-inference-id
-              inference_results:
+              results:
                 - text: "inference test"
                   inference:
                     feature_1: 0.1
@@ -141,12 +141,12 @@ setup:
         id: doc_1
         body:
           non_inference_field: "you know, for testing"
-          _semantic_text_inference:
+          _inference:
             dense_field:
               model_settings:
                 inference_id: sparse-inference-id
                 task_type: text_embedding
-              inference_results:
+              results:
                 - text: "inference test"
                   inference: [0.1, 0.2, 0.3, 0.4, 0.5]
                 - text: "another inference test"