From 86ddc9d8b8aff0285890732d2164c589b87ce3dc Mon Sep 17 00:00:00 2001 From: Jim Ferenczi Date: Mon, 18 Mar 2024 15:12:23 +0000 Subject: [PATCH] add more tests --- .../action/bulk/BulkShardRequest.java | 13 +- .../vectors/DenseVectorFieldMapper.java | 4 + .../xpack/inference/InferencePlugin.java | 4 +- .../ShardBulkInferenceActionFilter.java | 133 ++++--- ...r.java => InferenceResultFieldMapper.java} | 55 ++- .../mapper/SemanticTextFieldMapper.java | 2 +- .../ShardBulkInferenceActionFilterTests.java | 340 ++++++++++++++++++ ...a => InferenceResultFieldMapperTests.java} | 146 ++++---- 8 files changed, 544 insertions(+), 153 deletions(-) rename x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/{SemanticTextInferenceResultFieldMapper.java => InferenceResultFieldMapper.java} (86%) create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java rename x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/{SemanticTextInferenceResultFieldMapperTests.java => InferenceResultFieldMapperTests.java} (79%) diff --git a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequest.java b/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequest.java index f6dd7902f1672..1b5494c6a68f5 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequest.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequest.java @@ -48,10 +48,10 @@ public BulkShardRequest(ShardId shardId, RefreshPolicy refreshPolicy, BulkItemRe } /** - * Set the transient metadata indicating that this request requires running inference - * before proceeding. + * Visible for testing. + * Sets the transient metadata indicating that this request requires running inference before proceeding.
*/ - void setFieldInferenceMetadata(Map> fieldsInferenceMetadata) { + public void setFieldInferenceMetadata(Map> fieldsInferenceMetadata) { this.fieldsInferenceMetadata = fieldsInferenceMetadata; } @@ -64,6 +64,13 @@ public Map> consumeFieldInferenceMetadata() { return ret; } + /** + * Visible for testing. + */ + public Map> getFieldsInferenceMetadata() { + return fieldsInferenceMetadata; + } + public long totalSizeInBytes() { long totalSizeInBytes = 0; for (int i = 0; i < items.length; i++) { diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index c6e4d4af926a2..53cc803fc5a2f 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -1086,6 +1086,10 @@ public String typeName() { return CONTENT_TYPE; } + public Integer getDims() { + return dims; + } + @Override public ValueFetcher valueFetcher(SearchExecutionContext context, String format) { if (format != null) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 7bfa06ecb9a20..994207766f2a6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -55,8 +55,8 @@ import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderFactory; import org.elasticsearch.xpack.inference.external.http.sender.RequestExecutorServiceSettings; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapper; import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; -import org.elasticsearch.xpack.inference.mapper.SemanticTextInferenceResultFieldMapper; import org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.elasticsearch.xpack.inference.rest.RestDeleteInferenceModelAction; import org.elasticsearch.xpack.inference.rest.RestGetInferenceModelAction; @@ -285,7 +285,7 @@ public Map getMappers() { @Override public Map getMetadataMappers() { - return Map.of(SemanticTextInferenceResultFieldMapper.NAME, SemanticTextInferenceResultFieldMapper.PARSER); + return Map.of(InferenceResultFieldMapper.NAME, InferenceResultFieldMapper.PARSER); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java index 176d2917b0b2a..e679d3c970abf 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -10,6 +10,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionRequest; @@ -33,11 +34,9 @@ import
org.elasticsearch.inference.InferenceServiceRegistry; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; +import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; -import org.elasticsearch.xpack.core.inference.results.ChunkedSparseEmbeddingResults; -import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingResults; -import org.elasticsearch.xpack.inference.mapper.SemanticTextInferenceResultFieldMapper; -import org.elasticsearch.xpack.inference.mapper.SemanticTextModelSettings; +import org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapper; import org.elasticsearch.xpack.inference.registry.ModelRegistry; import java.util.ArrayList; @@ -50,7 +49,7 @@ /** * An {@link ActionFilter} that performs inference on {@link BulkShardRequest} asynchronously and stores the results in - * the individual {@link BulkItemRequest}. The results are then consumed by the {@link SemanticTextInferenceResultFieldMapper} + * the individual {@link BulkItemRequest}. The results are then consumed by the {@link InferenceResultFieldMapper} * in the subsequent {@link TransportShardBulkAction} downstream. */ public class ShardBulkInferenceActionFilter implements ActionFilter { @@ -82,7 +81,7 @@ public void app case TransportShardBulkAction.ACTION_NAME: BulkShardRequest bulkShardRequest = (BulkShardRequest) request; var fieldInferenceMetadata = bulkShardRequest.consumeFieldInferenceMetadata(); - if (fieldInferenceMetadata != null) { + if (fieldInferenceMetadata != null && fieldInferenceMetadata.size() > 0) { Runnable onInferenceCompletion = () -> chain.proceed(task, action, request, listener); processBulkShardRequest(fieldInferenceMetadata, bulkShardRequest, onInferenceCompletion); } else { @@ -110,18 +109,7 @@ private record FieldInferenceRequest(int id, String field, String input) {} private record FieldInferenceResponse(String field, Model model, ChunkedInferenceServiceResults chunkedResults) {} - private record FieldInferenceResponseAccumulator(int id, List responses, List failures) { - Exception createFailureOrNull() { - if (failures.isEmpty()) { - return null; - } - Exception main = failures.get(0); - for (int i = 1; i < failures.size(); i++) { - main.addSuppressed(failures.get(i)); - } - return main; - } - } + private record FieldInferenceResponseAccumulator(int id, List responses, List failures) {} private class AsyncBulkShardInferenceAction implements Runnable { private final Map> fieldInferenceMetadata; @@ -147,7 +135,11 @@ public void run() { try { for (var inferenceResponse : inferenceResults.asList()) { var request = bulkShardRequest.items()[inferenceResponse.id]; - applyInference(request, inferenceResponse); + try { + applyInferenceResponses(request, inferenceResponse); + } catch (Exception exc) { + request.abort(bulkShardRequest.index(), exc); + } } } finally { onCompletion.run(); @@ -189,8 +181,8 @@ public void onResponse(ModelRegistry.UnparsedModel unparsedModel) { var request = requests.get(i); inferenceResults.get(request.id).failures.add( new ResourceNotFoundException( - "Inference service [{}] not found for field [{}]", - unparsedModel.service(), + "Inference id [{}] not found for field [{}]", + inferenceId, request.field ) ); @@ -221,9 +213,8 @@ public void onResponse(List results) { for (int i = 0; i < results.size(); i++) { var request = requests.get(i); var result = results.get(i); - inferenceResults.get(request.id).responses.add( - new FieldInferenceResponse(request.field, inferenceProvider.model, result) - ); + 
var acc = inferenceResults.get(request.id); + acc.responses.add(new FieldInferenceResponse(request.field, inferenceProvider.model, result)); } } @@ -254,38 +245,34 @@ public void onFailure(Exception exc) { } /** - * Apply the {@link FieldInferenceResponseAccumulator} to the provider {@link BulkItemRequest}. + * Applies the {@link FieldInferenceResponseAccumulator} to the provided {@link BulkItemRequest}. * If the response contains failures, the bulk item request is marked as failed for the downstream action. * Otherwise, the source of the request is augmented with the field inference results. */ - private void applyInference(BulkItemRequest request, FieldInferenceResponseAccumulator inferenceResponse) { - Exception failure = inferenceResponse.createFailureOrNull(); - if (failure != null) { - request.abort(bulkShardRequest.index(), failure); + private void applyInferenceResponses(BulkItemRequest item, FieldInferenceResponseAccumulator response) { + if (response.failures().isEmpty() == false) { + for (var failure : response.failures()) { + item.abort(item.index(), failure); + } return; } - final IndexRequest indexRequest = getIndexRequestOrNull(request.request()); - final Map newDocMap = indexRequest.sourceAsMap(); - final Map inferenceMetadataMap = new LinkedHashMap<>(); - newDocMap.put(SemanticTextInferenceResultFieldMapper.NAME, inferenceMetadataMap); - for (FieldInferenceResponse fieldResponse : inferenceResponse.responses) { - List> chunks = new ArrayList<>(); - if (fieldResponse.chunkedResults instanceof ChunkedSparseEmbeddingResults textExpansionResults) { - for (var chunk : textExpansionResults.getChunkedResults()) { - chunks.add(chunk.asMap()); - } - } else if (fieldResponse.chunkedResults instanceof ChunkedTextEmbeddingResults textEmbeddingResults) { - for (var chunk : textEmbeddingResults.getChunks()) { - chunks.add(chunk.asMap()); - } - } else { - request.abort(bulkShardRequest.index(), new IllegalArgumentException("TODO")); - return; + + final IndexRequest indexRequest = getIndexRequestOrNull(item.request()); + Map newDocMap = indexRequest.sourceAsMap(); + Map inferenceMap = new LinkedHashMap<>(); + // ignore the existing inference map if any + newDocMap.put(InferenceResultFieldMapper.NAME, inferenceMap); + for (FieldInferenceResponse fieldResponse : response.responses()) { + try { + InferenceResultFieldMapper.applyFieldInference( + inferenceMap, + fieldResponse.field(), + fieldResponse.model(), + fieldResponse.chunkedResults() + ); + } catch (Exception exc) { + item.abort(item.index(), exc); + } - Map fieldMap = new LinkedHashMap<>(); - fieldMap.putAll(new SemanticTextModelSettings(fieldResponse.model).asMap()); - fieldMap.put(SemanticTextInferenceResultFieldMapper.INFERENCE_RESULTS, chunks); - inferenceMetadataMap.put(fieldResponse.field, fieldMap); } indexRequest.source(newDocMap); } @@ -294,7 +281,7 @@ private Map> createFieldInferenceRequests(Bu Map> fieldRequestsMap = new LinkedHashMap<>(); for (var item : bulkShardRequest.items()) { if (item.getPrimaryResponse() != null) { - // item was already aborted/processed by a filter in the chain upstream (e.g.
security) continue; } final IndexRequest indexRequest = getIndexRequestOrNull(item.request()); @@ -302,30 +289,38 @@ private Map> createFieldInferenceRequests(Bu continue; } final Map docMap = indexRequest.sourceAsMap(); - List fieldRequests = null; - for (var pair : fieldInferenceMetadata.entrySet()) { - String inferenceId = pair.getKey(); - for (var field : pair.getValue()) { + for (var entry : fieldInferenceMetadata.entrySet()) { + String inferenceId = entry.getKey(); + for (var field : entry.getValue()) { var value = XContentMapValues.extractValue(field, docMap); if (value == null) { continue; } - if (value instanceof String valueStr) { - if (inferenceResults.get(item.id()) == null) { - inferenceResults.set( + if (inferenceResults.get(item.id()) == null) { + inferenceResults.set( + item.id(), + new FieldInferenceResponseAccumulator( item.id(), - new FieldInferenceResponseAccumulator( - item.id(), - Collections.synchronizedList(new ArrayList<>()), - Collections.synchronizedList(new ArrayList<>()) - ) - ); - } - if (fieldRequests == null) { - fieldRequests = new ArrayList<>(); - fieldRequestsMap.put(inferenceId, fieldRequests); - } + Collections.synchronizedList(new ArrayList<>()), + Collections.synchronizedList(new ArrayList<>()) + ) + ); + } + if (value instanceof String valueStr) { + List fieldRequests = fieldRequestsMap.computeIfAbsent( + inferenceId, + k -> new ArrayList<>() + ); fieldRequests.add(new FieldInferenceRequest(item.id(), field, valueStr)); + } else { + inferenceResults.get(item.id()).failures.add( + new ElasticsearchStatusException( + "Invalid format for field [{}], expected [String] got [{}]", + RestStatus.BAD_REQUEST, + field, + value.getClass().getSimpleName() + ) + ); } } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapper.java similarity index 86% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapper.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapper.java index cee6395185060..2ede5419ab74e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapper.java @@ -8,6 +8,8 @@ package org.elasticsearch.xpack.inference.mapper; import org.apache.lucene.search.Query; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.common.Strings; import org.elasticsearch.index.IndexVersion; import org.elasticsearch.index.mapper.DocumentParserContext; @@ -27,15 +29,24 @@ import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper; import org.elasticsearch.index.query.SearchExecutionContext; +import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.Model; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.logging.LogManager; import org.elasticsearch.logging.Logger; +import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xcontent.XContentParser; +import 
org.elasticsearch.xpack.core.inference.results.ChunkedSparseEmbeddingResults; +import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingResults; import java.io.IOException; +import java.util.ArrayList; import java.util.Collections; import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; import java.util.Set; import java.util.stream.Collectors; @@ -51,7 +62,7 @@ * { * "_source": { * "my_semantic_text_field": "these are not the droids you're looking for", - * "_semantic_text_inference": { + * "_inference": { * "my_semantic_text_field": [ * { * "sparse_embedding": { @@ -94,17 +105,17 @@ * } * */ -public class SemanticTextInferenceResultFieldMapper extends MetadataFieldMapper { +public class InferenceResultFieldMapper extends MetadataFieldMapper { public static final String NAME = "_inference"; public static final String CONTENT_TYPE = "_inference"; - public static final String INFERENCE_RESULTS = "results"; + public static final String RESULTS = "results"; public static final String INFERENCE_CHUNKS_RESULTS = "inference"; public static final String INFERENCE_CHUNKS_TEXT = "text"; - public static final TypeParser PARSER = new FixedTypeParser(c -> new SemanticTextInferenceResultFieldMapper()); + public static final TypeParser PARSER = new FixedTypeParser(c -> new InferenceResultFieldMapper()); - private static final Logger logger = LogManager.getLogger(SemanticTextInferenceResultFieldMapper.class); + private static final Logger logger = LogManager.getLogger(InferenceResultFieldMapper.class); private static final Set REQUIRED_SUBFIELDS = Set.of(INFERENCE_CHUNKS_TEXT, INFERENCE_CHUNKS_RESULTS); @@ -131,7 +142,7 @@ public Query termQuery(Object value, SearchExecutionContext context) { } } - private SemanticTextInferenceResultFieldMapper() { + public InferenceResultFieldMapper() { super(SemanticTextInferenceFieldType.INSTANCE); } @@ -172,7 +183,7 @@ private static void parseSingleField(DocumentParserContext context, MapperBuilde failIfTokenIsNot(parser, XContentParser.Token.FIELD_NAME); String currentName = parser.currentName(); - if (INFERENCE_RESULTS.equals(currentName)) { + if (RESULTS.equals(currentName)) { NestedObjectMapper nestedObjectMapper = createInferenceResultsObjectMapper( context, mapperBuilderContext, @@ -328,4 +339,34 @@ protected String contentType() { public SourceLoader.SyntheticFieldLoader syntheticFieldLoader() { return SourceLoader.SyntheticFieldLoader.NOTHING; } + + public static void applyFieldInference( + Map inferenceMap, + String field, + Model model, + ChunkedInferenceServiceResults results + ) throws ElasticsearchException { + List> chunks = new ArrayList<>(); + if (results instanceof ChunkedSparseEmbeddingResults textExpansionResults) { + for (var chunk : textExpansionResults.getChunkedResults()) { + chunks.add(chunk.asMap()); + } + } else if (results instanceof ChunkedTextEmbeddingResults textEmbeddingResults) { + for (var chunk : textEmbeddingResults.getChunks()) { + chunks.add(chunk.asMap()); + } + } else { + throw new ElasticsearchStatusException( + "Invalid inference results format for field [{}] with inference id [{}], got {}", + RestStatus.BAD_REQUEST, + field, + model.getInferenceEntityId(), + results.getWriteableName() + ); + } + Map fieldMap = new LinkedHashMap<>(); + fieldMap.putAll(new SemanticTextModelSettings(model).asMap()); + fieldMap.put(InferenceResultFieldMapper.RESULTS, chunks); + inferenceMap.put(field, fieldMap); + } } diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java index 027b85a9a9f45..4caa3d68ba877 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java @@ -30,7 +30,7 @@ * at ingestion and query time. * For now, it is compatible with text expansion models only, but will be extended to support dense vector models as well. * This field mapper performs no indexing, as inference results will be included as a different field in the document source, and will - * be indexed using {@link SemanticTextInferenceResultFieldMapper}. + * be indexed using {@link InferenceResultFieldMapper}. */ public class SemanticTextFieldMapper extends FieldMapper { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java new file mode 100644 index 0000000000000..7f3ffbe596543 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java @@ -0,0 +1,340 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.action.filter; + +import org.elasticsearch.ResourceNotFoundException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.bulk.BulkItemRequest; +import org.elasticsearch.action.bulk.BulkItemResponse; +import org.elasticsearch.action.bulk.BulkShardRequest; +import org.elasticsearch.action.bulk.TransportShardBulkAction; +import org.elasticsearch.action.index.IndexRequest; +import org.elasticsearch.action.support.ActionFilterChain; +import org.elasticsearch.action.support.WriteRequest; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.index.shard.ShardId; +import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.InferenceService; +import org.elasticsearch.inference.InferenceServiceRegistry; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xcontent.json.JsonXContent; +import org.elasticsearch.xpack.core.inference.results.ChunkedSparseEmbeddingResults; +import org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapper; +import org.elasticsearch.xpack.inference.model.TestModel; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.junit.After; +import org.junit.Before; +import org.mockito.stubbing.Answer; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import 
java.util.Optional; +import java.util.Set; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.awaitLatch; +import static org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapperTests.randomSparseEmbeddings; +import static org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapperTests.randomTextEmbeddings; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.instanceOf; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class ShardBulkInferenceActionFilterTests extends ESTestCase { + private ThreadPool threadPool; + + @Before + public void setupThreadPool() { + threadPool = new TestThreadPool(getTestName()); + } + + @After + public void tearDownThreadPool() throws Exception { + terminate(threadPool); + } + + @SuppressWarnings({ "unchecked", "rawtypes" }) + public void testFilterNoop() throws Exception { + ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of()); + CountDownLatch chainExecuted = new CountDownLatch(1); + ActionFilterChain actionFilterChain = (task, action, request, listener) -> { + try { + assertNull(((BulkShardRequest) request).getFieldsInferenceMetadata()); + } finally { + chainExecuted.countDown(); + } + }; + ActionListener actionListener = mock(ActionListener.class); + Task task = mock(Task.class); + BulkShardRequest request = new BulkShardRequest( + new ShardId("test", "test", 0), + WriteRequest.RefreshPolicy.NONE, + new BulkItemRequest[0] + ); + request.setFieldInferenceMetadata(Map.of("foo", Set.of("bar"))); + filter.apply(task, TransportShardBulkAction.ACTION_NAME, request, actionListener, actionFilterChain); + awaitLatch(chainExecuted, 10, TimeUnit.SECONDS); + } + + @SuppressWarnings({ "unchecked", "rawtypes" }) + public void testInferenceNotFound() throws Exception { + StaticModel model = randomStaticModel(); + ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(model.getInferenceEntityId(), model)); + CountDownLatch chainExecuted = new CountDownLatch(1); + ActionFilterChain actionFilterChain = (task, action, request, listener) -> { + try { + BulkShardRequest bulkShardRequest = (BulkShardRequest) request; + assertNull(bulkShardRequest.getFieldsInferenceMetadata()); + for (BulkItemRequest item : bulkShardRequest.items()) { + assertNotNull(item.getPrimaryResponse()); + assertTrue(item.getPrimaryResponse().isFailed()); + BulkItemResponse.Failure failure = item.getPrimaryResponse().getFailure(); + assertThat(failure.getStatus(), equalTo(RestStatus.NOT_FOUND)); + } + } finally { + chainExecuted.countDown(); + } + }; + ActionListener actionListener = mock(ActionListener.class); + Task task = mock(Task.class); + + Map> inferenceFields = Map.of( + model.getInferenceEntityId(), + Set.of("field1"), + "inference_0", + Set.of("field2", "field3") + ); + BulkItemRequest[] items = new BulkItemRequest[10]; + for (int i = 0; i < items.length; i++) { + items[i] = randomBulkItemRequest(i, Map.of(), inferenceFields)[0]; + } + BulkShardRequest request = new BulkShardRequest(new ShardId("test", "test", 0), WriteRequest.RefreshPolicy.NONE, items); + request.setFieldInferenceMetadata(inferenceFields); + filter.apply(task, TransportShardBulkAction.ACTION_NAME, request, 
actionListener, actionFilterChain); + awaitLatch(chainExecuted, 10, TimeUnit.SECONDS); + } + + @SuppressWarnings({ "unchecked", "rawtypes" }) + public void testManyRandomDocs() throws Exception { + Map inferenceModelMap = new HashMap<>(); + int numModels = randomIntBetween(1, 5); + for (int i = 0; i < numModels; i++) { + StaticModel model = randomStaticModel(); + inferenceModelMap.put(model.getInferenceEntityId(), model); + } + + int numInferenceFields = randomIntBetween(1, 5); + Map> inferenceFields = new HashMap<>(); + for (int i = 0; i < numInferenceFields; i++) { + String inferenceId = randomFrom(inferenceModelMap.keySet()); + String field = randomAlphaOfLengthBetween(5, 10); + var res = inferenceFields.computeIfAbsent(inferenceId, k -> new HashSet<>()); + res.add(field); + } + + int numRequests = randomIntBetween(100, 1000); + BulkItemRequest[] originalRequests = new BulkItemRequest[numRequests]; + BulkItemRequest[] modifiedRequests = new BulkItemRequest[numRequests]; + for (int id = 0; id < numRequests; id++) { + BulkItemRequest[] res = randomBulkItemRequest(id, inferenceModelMap, inferenceFields); + originalRequests[id] = res[0]; + modifiedRequests[id] = res[1]; + } + + ShardBulkInferenceActionFilter filter = createFilter(threadPool, inferenceModelMap); + CountDownLatch chainExecuted = new CountDownLatch(1); + ActionFilterChain actionFilterChain = (task, action, request, listener) -> { + try { + assertThat(request, instanceOf(BulkShardRequest.class)); + BulkShardRequest bulkShardRequest = (BulkShardRequest) request; + assertNull(bulkShardRequest.getFieldsInferenceMetadata()); + BulkItemRequest[] items = bulkShardRequest.items(); + assertThat(items.length, equalTo(originalRequests.length)); + for (int id = 0; id < items.length; id++) { + IndexRequest actualRequest = ShardBulkInferenceActionFilter.getIndexRequestOrNull(items[id].request()); + IndexRequest expectedRequest = ShardBulkInferenceActionFilter.getIndexRequestOrNull(modifiedRequests[id].request()); + try { + assertToXContentEquivalent(expectedRequest.source(), actualRequest.source(), actualRequest.getContentType()); + } catch (Exception exc) { + throw new IllegalStateException(exc); + } + } + } finally { + chainExecuted.countDown(); + } + }; + ActionListener actionListener = mock(ActionListener.class); + Task task = mock(Task.class); + BulkShardRequest original = new BulkShardRequest(new ShardId("test", "test", 0), WriteRequest.RefreshPolicy.NONE, originalRequests); + original.setFieldInferenceMetadata(inferenceFields); + filter.apply(task, TransportShardBulkAction.ACTION_NAME, original, actionListener, actionFilterChain); + awaitLatch(chainExecuted, 10, TimeUnit.SECONDS); + } + + @SuppressWarnings("unchecked") + private static ShardBulkInferenceActionFilter createFilter(ThreadPool threadPool, Map modelMap) { + ModelRegistry modelRegistry = mock(ModelRegistry.class); + Answer unparsedModelAnswer = invocationOnMock -> { + String id = (String) invocationOnMock.getArguments()[0]; + ActionListener listener = (ActionListener) invocationOnMock + .getArguments()[1]; + var model = modelMap.get(id); + if (model != null) { + listener.onResponse( + new ModelRegistry.UnparsedModel( + model.getInferenceEntityId(), + model.getTaskType(), + model.getServiceSettings().model(), + XContentHelper.convertToMap(JsonXContent.jsonXContent, Strings.toString(model.getTaskSettings()), false), + XContentHelper.convertToMap(JsonXContent.jsonXContent, Strings.toString(model.getSecretSettings()), false) + ) + ); + } else { + listener.onFailure(new 
ResourceNotFoundException("model id [{}] not found", id)); + } + return null; + }; + doAnswer(unparsedModelAnswer).when(modelRegistry).getModelWithSecrets(any(), any()); + + InferenceService inferenceService = mock(InferenceService.class); + Answer chunkedInferAnswer = invocationOnMock -> { + StaticModel model = (StaticModel) invocationOnMock.getArguments()[0]; + List inputs = (List) invocationOnMock.getArguments()[1]; + ActionListener> listener = (ActionListener< + List>) invocationOnMock.getArguments()[5]; + Runnable runnable = () -> { + List results = new ArrayList<>(); + for (String input : inputs) { + results.add(model.getResults(input)); + } + listener.onResponse(results); + }; + if (randomBoolean()) { + try { + threadPool.generic().execute(runnable); + } catch (Exception exc) { + listener.onFailure(exc); + } + } else { + runnable.run(); + } + return null; + }; + doAnswer(chunkedInferAnswer).when(inferenceService).chunkedInfer(any(), any(), any(), any(), any(), any()); + + Answer modelAnswer = invocationOnMock -> { + String inferenceId = (String) invocationOnMock.getArguments()[0]; + return modelMap.get(inferenceId); + }; + doAnswer(modelAnswer).when(inferenceService).parsePersistedConfigWithSecrets(any(), any(), any(), any()); + + InferenceServiceRegistry inferenceServiceRegistry = mock(InferenceServiceRegistry.class); + when(inferenceServiceRegistry.getService(any())).thenReturn(Optional.of(inferenceService)); + ShardBulkInferenceActionFilter filter = new ShardBulkInferenceActionFilter(inferenceServiceRegistry, modelRegistry); + return filter; + } + + private static BulkItemRequest[] randomBulkItemRequest( + int id, + Map modelMap, + Map> inferenceFieldMap + ) { + Map docMap = new LinkedHashMap<>(); + Map inferenceResultsMap = new LinkedHashMap<>(); + for (var entry : inferenceFieldMap.entrySet()) { + String inferenceId = entry.getKey(); + var model = modelMap.get(inferenceId); + for (var field : entry.getValue()) { + String text = randomAlphaOfLengthBetween(10, 100); + docMap.put(field, text); + if (model == null) { + // ignore results, the doc should fail with a resource not found exception + continue; + } + int numChunks = randomIntBetween(1, 5); + List chunks = new ArrayList<>(); + for (int i = 0; i < numChunks; i++) { + chunks.add(randomAlphaOfLengthBetween(5, 10)); + } + TaskType taskType = model.getTaskType(); + final ChunkedInferenceServiceResults results; + switch (taskType) { + case TEXT_EMBEDDING: + results = randomTextEmbeddings(chunks); + break; + + case SPARSE_EMBEDDING: + results = randomSparseEmbeddings(chunks); + break; + + default: + throw new AssertionError("Unknown task type " + taskType.name()); + } + model.putResult(text, results); + InferenceResultFieldMapper.applyFieldInference(inferenceResultsMap, field, model, results); + } + } + Map expectedDocMap = new LinkedHashMap<>(docMap); + expectedDocMap.put(InferenceResultFieldMapper.NAME, inferenceResultsMap); + return new BulkItemRequest[] { + new BulkItemRequest(id, new IndexRequest("index").source(docMap)), + new BulkItemRequest(id, new IndexRequest("index").source(expectedDocMap)) }; + } + + private static StaticModel randomStaticModel() { + String serviceName = randomAlphaOfLengthBetween(5, 10); + String inferenceId = randomAlphaOfLengthBetween(5, 10); + return new StaticModel( + inferenceId, + randomBoolean() ? 
TaskType.TEXT_EMBEDDING : TaskType.SPARSE_EMBEDDING, + serviceName, + new TestModel.TestServiceSettings("my-model"), + new TestModel.TestTaskSettings(randomIntBetween(1, 100)), + new TestModel.TestSecretSettings(randomAlphaOfLength(10)) + ); + } + + private static class StaticModel extends TestModel { + private final Map resultMap; + + StaticModel( + String inferenceEntityId, + TaskType taskType, + String service, + TestServiceSettings serviceSettings, + TestTaskSettings taskSettings, + TestSecretSettings secretSettings + ) { + super(inferenceEntityId, taskType, service, serviceSettings, taskSettings, secretSettings); + this.resultMap = new HashMap<>(); + } + + ChunkedInferenceServiceResults getResults(String text) { + return resultMap.getOrDefault(text, new ChunkedSparseEmbeddingResults(List.of())); + } + + void putResult(String text, ChunkedInferenceServiceResults results) { + resultMap.put(text, results); + } + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapperTests.java similarity index 79% rename from x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapperTests.java rename to x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapperTests.java index 5dc245298838f..b5d75b528c6ab 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextInferenceResultFieldMapperTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapperTests.java @@ -31,48 +31,46 @@ import org.elasticsearch.index.mapper.NestedObjectMapper; import org.elasticsearch.index.mapper.ParsedDocument; import org.elasticsearch.index.search.ESToParentBlockJoinQuery; +import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.Model; import org.elasticsearch.inference.TaskType; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.search.LeafNestedDocuments; import org.elasticsearch.search.NestedDocuments; import org.elasticsearch.search.SearchHit; import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; +import org.elasticsearch.xpack.core.inference.results.ChunkedSparseEmbeddingResults; +import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingResults; +import org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextExpansionResults; +import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; import org.elasticsearch.xpack.inference.InferencePlugin; +import org.elasticsearch.xpack.inference.model.TestModel; import java.io.IOException; import java.util.ArrayList; -import java.util.Arrays; import java.util.Collection; import java.util.HashMap; import java.util.HashSet; -import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Set; import java.util.function.Consumer; -import static org.elasticsearch.xpack.inference.mapper.SemanticTextInferenceResultFieldMapper.INFERENCE_CHUNKS_RESULTS; -import static org.elasticsearch.xpack.inference.mapper.SemanticTextInferenceResultFieldMapper.INFERENCE_CHUNKS_TEXT; -import static org.elasticsearch.xpack.inference.mapper.SemanticTextInferenceResultFieldMapper.INFERENCE_RESULTS; +import 
static org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapper.INFERENCE_CHUNKS_RESULTS; +import static org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapper.INFERENCE_CHUNKS_TEXT; +import static org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapper.RESULTS; import static org.hamcrest.Matchers.containsString; -public class SemanticTextInferenceResultFieldMapperTests extends MetadataMapperTestCase { - private record SemanticTextInferenceResults(String fieldName, SparseEmbeddingResults sparseEmbeddingResults, List text) { - private SemanticTextInferenceResults { - if (sparseEmbeddingResults.embeddings().size() != text.size()) { - throw new IllegalArgumentException("Sparse embeddings and text must be the same size"); - } - } - } +public class InferenceResultFieldMapperTests extends MetadataMapperTestCase { + private record SemanticTextInferenceResults(String fieldName, ChunkedInferenceServiceResults results, List text) {} - private record VisitedChildDocInfo(String path, int sparseVectorDims) {} + private record VisitedChildDocInfo(String path, int numChunks) {} private record SparseVectorSubfieldOptions(boolean include, boolean includeEmbedding, boolean includeIsTruncated) {} @Override protected String fieldName() { - return SemanticTextInferenceResultFieldMapper.NAME; + return InferenceResultFieldMapper.NAME; } @Override @@ -108,8 +106,8 @@ public void testSuccessfulParse() throws IOException { b -> addSemanticTextInferenceResults( b, List.of( - generateSemanticTextinferenceResults(fieldName1, List.of("a b", "c")), - generateSemanticTextinferenceResults(fieldName2, List.of("d e f")) + randomSemanticTextInferenceResults(fieldName1, List.of("a b", "c")), + randomSemanticTextInferenceResults(fieldName2, List.of("d e f")) ) ) ) @@ -208,10 +206,10 @@ public void testMissingSubfields() throws IOException { source( b -> addSemanticTextInferenceResults( b, - List.of(generateSemanticTextinferenceResults(fieldName, List.of("a b"))), + List.of(randomSemanticTextInferenceResults(fieldName, List.of("a b"))), new SparseVectorSubfieldOptions(false, true, true), true, - null + Map.of() ) ) ) @@ -226,10 +224,10 @@ public void testMissingSubfields() throws IOException { source( b -> addSemanticTextInferenceResults( b, - List.of(generateSemanticTextinferenceResults(fieldName, List.of("a b"))), + List.of(randomSemanticTextInferenceResults(fieldName, List.of("a b"))), new SparseVectorSubfieldOptions(true, true, true), false, - null + Map.of() ) ) ) @@ -244,10 +242,10 @@ public void testMissingSubfields() throws IOException { source( b -> addSemanticTextInferenceResults( b, - List.of(generateSemanticTextinferenceResults(fieldName, List.of("a b"))), + List.of(randomSemanticTextInferenceResults(fieldName, List.of("a b"))), new SparseVectorSubfieldOptions(false, true, true), false, - null + Map.of() ) ) ) @@ -262,7 +260,7 @@ public void testMissingSubfields() throws IOException { public void testExtraSubfields() throws IOException { final String fieldName = randomAlphaOfLengthBetween(5, 15); final List semanticTextInferenceResultsList = List.of( - generateSemanticTextinferenceResults(fieldName, List.of("a b")) + randomSemanticTextInferenceResults(fieldName, List.of("a b")) ); DocumentMapper documentMapper = createDocumentMapper(mapping(b -> addSemanticTextMapping(b, fieldName, randomAlphaOfLength(8)))); @@ -360,7 +358,7 @@ public void testMissingSemanticTextMapping() throws IOException { DocumentParsingException.class, DocumentParsingException.class, () -> 
documentMapper.parse( - source(b -> addSemanticTextInferenceResults(b, List.of(generateSemanticTextinferenceResults(fieldName, List.of("a b"))))) + source(b -> addSemanticTextInferenceResults(b, List.of(randomSemanticTextInferenceResults(fieldName, List.of("a b"))))) ) ); assertThat( @@ -378,18 +376,32 @@ private static void addSemanticTextMapping(XContentBuilder mappingBuilder, Strin mappingBuilder.endObject(); } - private static SemanticTextInferenceResults generateSemanticTextinferenceResults(String semanticTextFieldName, List chunks) { - List embeddings = new ArrayList<>(chunks.size()); - for (String chunk : chunks) { - String[] tokens = chunk.split("\\s+"); - List weightedTokens = Arrays.stream(tokens) - .map(t -> new SparseEmbeddingResults.WeightedToken(t, randomFloat())) - .toList(); + public static ChunkedTextEmbeddingResults randomTextEmbeddings(List inputs) { + List chunks = new ArrayList<>(); + for (String input : inputs) { + double[] values = new double[5]; + for (int j = 0; j < values.length; j++) { + values[j] = randomDouble(); + } + chunks.add(new org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextEmbeddingResults.EmbeddingChunk(input, values)); + } + return new ChunkedTextEmbeddingResults(chunks); + } - embeddings.add(new SparseEmbeddingResults.Embedding(weightedTokens, false)); + public static ChunkedSparseEmbeddingResults randomSparseEmbeddings(List inputs) { + List chunks = new ArrayList<>(); + for (String input : inputs) { + var tokens = new ArrayList(); + for (var token : input.split("\\s+")) { + tokens.add(new TextExpansionResults.WeightedToken(token, randomFloat())); + } + chunks.add(new ChunkedTextExpansionResults.ChunkedResult(input, tokens)); } + return new ChunkedSparseEmbeddingResults(chunks); + } - return new SemanticTextInferenceResults(semanticTextFieldName, new SparseEmbeddingResults(embeddings), chunks); + private static SemanticTextInferenceResults randomSemanticTextInferenceResults(String semanticTextFieldName, List chunks) { + return new SemanticTextInferenceResults(semanticTextFieldName, randomSparseEmbeddings(chunks), chunks); } private static void addSemanticTextInferenceResults( @@ -401,10 +413,11 @@ private static void addSemanticTextInferenceResults( semanticTextInferenceResults, new SparseVectorSubfieldOptions(true, true, true), true, - null + Map.of() ); } + @SuppressWarnings("unchecked") private static void addSemanticTextInferenceResults( XContentBuilder sourceBuilder, List semanticTextInferenceResults, @@ -412,48 +425,39 @@ private static void addSemanticTextInferenceResults( boolean includeTextSubfield, Map extraSubfields ) throws IOException { - - Map> inferenceResultsMap = new HashMap<>(); + Map inferenceResultsMap = new HashMap<>(); for (SemanticTextInferenceResults semanticTextInferenceResult : semanticTextInferenceResults) { - Map fieldMap = new HashMap<>(); - fieldMap.put(SemanticTextModelSettings.NAME, modelSettingsMap()); - List> parsedInferenceResults = new ArrayList<>(semanticTextInferenceResult.text().size()); - - Iterator embeddingsIterator = semanticTextInferenceResult.sparseEmbeddingResults() - .embeddings() - .iterator(); - Iterator textIterator = semanticTextInferenceResult.text().iterator(); - while (embeddingsIterator.hasNext() && textIterator.hasNext()) { - SparseEmbeddingResults.Embedding embedding = embeddingsIterator.next(); - String text = textIterator.next(); - - Map subfieldMap = new HashMap<>(); - if (sparseVectorSubfieldOptions.include()) { - subfieldMap.put(INFERENCE_CHUNKS_RESULTS, 
embedding.asMap().get(SparseEmbeddingResults.Embedding.EMBEDDING)); - } - if (includeTextSubfield) { - subfieldMap.put(INFERENCE_CHUNKS_TEXT, text); + InferenceResultFieldMapper.applyFieldInference( + inferenceResultsMap, + semanticTextInferenceResult.fieldName, + randomModel(), + semanticTextInferenceResult.results + ); + Map optionsMap = (Map) inferenceResultsMap.get(semanticTextInferenceResult.fieldName); + List> fieldResultList = (List>) optionsMap.get(RESULTS); + for (var entry : fieldResultList) { + if (includeTextSubfield == false) { + entry.remove(INFERENCE_CHUNKS_TEXT); } - if (extraSubfields != null) { - subfieldMap.putAll(extraSubfields); + if (sparseVectorSubfieldOptions.include == false) { + entry.remove(INFERENCE_CHUNKS_RESULTS); } - - parsedInferenceResults.add(subfieldMap); + entry.putAll(extraSubfields); } - - fieldMap.put(INFERENCE_RESULTS, parsedInferenceResults); - inferenceResultsMap.put(semanticTextInferenceResult.fieldName(), fieldMap); } - - sourceBuilder.field(SemanticTextInferenceResultFieldMapper.NAME, inferenceResultsMap); + sourceBuilder.field(InferenceResultFieldMapper.NAME, inferenceResultsMap); } - private static Map modelSettingsMap() { - return Map.of( - SemanticTextModelSettings.TASK_TYPE_FIELD.getPreferredName(), - TaskType.SPARSE_EMBEDDING.toString(), - SemanticTextModelSettings.INFERENCE_ID_FIELD.getPreferredName(), - randomAlphaOfLength(8) + private static Model randomModel() { + String serviceName = randomAlphaOfLengthBetween(5, 10); + String inferenceId = randomAlphaOfLengthBetween(5, 10); + return new TestModel( + inferenceId, + TaskType.SPARSE_EMBEDDING, + serviceName, + new TestModel.TestServiceSettings("my-model"), + new TestModel.TestTaskSettings(randomIntBetween(1, 100)), + new TestModel.TestSecretSettings(randomAlphaOfLength(10)) ); }
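
Reviewer note: as a reference, here is a minimal, self-contained sketch (plain JDK collections only, no Elasticsearch types) of the document-source augmentation that applyFieldInference performs in this patch. The literal keys "_inference", "results", "text", and "inference" match the constants in InferenceResultFieldMapper above; the "task_type"/"inference_id" key spellings are an assumption about how SemanticTextModelSettings serializes (not shown in this patch), and the class name, field names, and weights are illustrative only.

    import java.util.LinkedHashMap;
    import java.util.List;
    import java.util.Map;

    public class InferenceMapSketch {
        public static void main(String[] args) {
            // Parsed document source, as returned by IndexRequest#sourceAsMap().
            Map<String, Object> docMap = new LinkedHashMap<>();
            docMap.put("my_semantic_text_field", "these are not the droids you're looking for");

            // One entry per chunk: the chunked text plus its embedding under "inference".
            Map<String, Object> chunk = new LinkedHashMap<>();
            chunk.put("text", "these are not the droids you're looking for");
            chunk.put("inference", Map.of("droids", 0.35f, "looking", 0.12f)); // illustrative weights

            // Per-field map: model settings first, then the chunk list under "results".
            Map<String, Object> fieldMap = new LinkedHashMap<>();
            fieldMap.put("task_type", "sparse_embedding"); // assumed serialized key
            fieldMap.put("inference_id", "my-inference-id"); // assumed serialized key
            fieldMap.put("results", List.of(chunk));

            // Any pre-existing "_inference" entry is replaced, mirroring applyInferenceResponses.
            Map<String, Object> inferenceMap = new LinkedHashMap<>();
            inferenceMap.put("my_semantic_text_field", fieldMap);
            docMap.put("_inference", inferenceMap);

            System.out.println(docMap);
        }
    }

The augmented map is what the filter hands back to the index request via indexRequest.source(newDocMap), and what InferenceResultFieldMapper later parses into nested Lucene documents.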
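A second toy illustration (JDK-only; the class and method names here are mine, not from the patch) of why FieldInferenceResponseAccumulator wraps its responses and failures in Collections.synchronizedList: as the mocked chunkedInfer in the test shows, inference callbacks for different fields of the same bulk item may complete on arbitrary threads, all appending to the same per-item accumulator.

    import java.util.ArrayList;
    import java.util.Collections;
    import java.util.List;
    import java.util.concurrent.CountDownLatch;

    public class AccumulatorSketch {
        // Mirrors the shape of FieldInferenceResponseAccumulator: one instance per bulk item.
        record Accumulator(int id, List<String> responses, List<Exception> failures) {}

        public static void main(String[] args) throws InterruptedException {
            Accumulator acc = new Accumulator(
                0,
                Collections.synchronizedList(new ArrayList<>()),
                Collections.synchronizedList(new ArrayList<>())
            );
            int fields = 4; // pretend four semantic text fields resolve concurrently
            CountDownLatch done = new CountDownLatch(fields);
            for (int i = 0; i < fields; i++) {
                int field = i;
                new Thread(() -> {
                    try {
                        // Each callback appends its chunked result to the shared accumulator.
                        acc.responses().add("chunks for field" + field);
                    } finally {
                        done.countDown();
                    }
                }).start();
            }
            done.await();
            // All writes land without races, thanks to the synchronized lists.
            System.out.println(acc.responses().size() + " responses, " + acc.failures().size() + " failures");
        }
    }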