From cfe1457543bafb5b25a4d2ef4fac6b904ceedc90 Mon Sep 17 00:00:00 2001 From: Jim Ferenczi Date: Wed, 3 Apr 2024 17:42:52 +0200 Subject: [PATCH] remove the inference metadata field mapper and implements all the logic in the semantic text field mapper --- .../action/update/TransportUpdateAction.java | 16 +- .../metadata/InferenceFieldMetadata.java | 6 +- .../index/mapper/DocumentParser.java | 4 + .../xpack/inference/InferencePlugin.java | 10 - .../ShardBulkInferenceActionFilter.java | 223 ++++--- .../mapper/InferenceMetadataFieldMapper.java | 456 ------------- .../inference/mapper/SemanticTextField.java | 328 +++++++++ .../mapper/SemanticTextFieldMapper.java | 328 +++++---- .../mapper/SemanticTextModelSettings.java | 181 ----- .../ShardBulkInferenceActionFilterTests.java | 53 +- .../InferenceMetadataFieldMapperTests.java | 629 ------------------ .../mapper/SemanticTextFieldMapperTests.java | 299 ++++++++- .../mapper/SemanticTextFieldTests.java | 219 ++++++ .../inference/10_semantic_text_inference.yml | 133 ++-- .../20_semantic_text_field_mapper.yml | 20 - 15 files changed, 1222 insertions(+), 1683 deletions(-) delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapper.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextModelSettings.java delete mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapperTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java diff --git a/server/src/main/java/org/elasticsearch/action/update/TransportUpdateAction.java b/server/src/main/java/org/elasticsearch/action/update/TransportUpdateAction.java index 36a47bc7e02e9..63ae56bfbd047 100644 --- a/server/src/main/java/org/elasticsearch/action/update/TransportUpdateAction.java +++ b/server/src/main/java/org/elasticsearch/action/update/TransportUpdateAction.java @@ -40,7 +40,6 @@ import org.elasticsearch.index.IndexNotFoundException; import org.elasticsearch.index.IndexService; import org.elasticsearch.index.engine.VersionConflictEngineException; -import org.elasticsearch.index.mapper.InferenceFieldMapper; import org.elasticsearch.index.shard.IndexShard; import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.indices.IndicesService; @@ -185,7 +184,7 @@ protected void shardOperation(final UpdateRequest request, final ActionListener< final UpdateHelper.Result result = updateHelper.prepare(request, indexShard, threadPool::absoluteTimeInMillis); switch (result.getResponseResult()) { case CREATED -> { - IndexRequest upsertRequest = removeInferenceMetadataField(indexService, result.action()); + IndexRequest upsertRequest = result.action(); // we fetch it from the index request so we don't generate the bytes twice, its already done in the index request final BytesReference upsertSourceBytes = upsertRequest.source(); client.bulk( @@ -227,7 +226,7 @@ protected void shardOperation(final UpdateRequest request, final ActionListener< ); } case UPDATED -> { - IndexRequest indexRequest = removeInferenceMetadataField(indexService, result.action()); + IndexRequest indexRequest = result.action(); // we fetch it from the index request so we don't generate the bytes twice, its already done in the index request final BytesReference indexSourceBytes = indexRequest.source(); client.bulk( @@ -336,15 +335,4 @@ private void handleUpdateFailureWithRetry( } listener.onFailure(cause instanceof Exception ? (Exception) cause : new NotSerializableExceptionWrapper(cause)); } - - private IndexRequest removeInferenceMetadataField(IndexService service, IndexRequest request) { - var inferenceMetadata = service.getIndexSettings().getIndexMetadata().getInferenceFields(); - if (inferenceMetadata.isEmpty()) { - return request; - } - Map docMap = request.sourceAsMap(); - docMap.remove(InferenceFieldMapper.NAME); - request.source(docMap); - return request; - } } diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadata.java index 947aa2c82640c..0cd3f05f250a3 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadata.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadata.java @@ -54,12 +54,14 @@ public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; InferenceFieldMetadata that = (InferenceFieldMetadata) o; - return inferenceId.equals(that.inferenceId) && Arrays.equals(sourceFields, that.sourceFields); + return Objects.equals(name, that.name) + && Objects.equals(inferenceId, that.inferenceId) + && Arrays.equals(sourceFields, that.sourceFields); } @Override public int hashCode() { - int result = Objects.hash(inferenceId); + int result = Objects.hash(name, inferenceId); result = 31 * result + Arrays.hashCode(sourceFields); return result; } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/DocumentParser.java b/server/src/main/java/org/elasticsearch/index/mapper/DocumentParser.java index 1fda9ababfabd..7357f6f4bdfc6 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/DocumentParser.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/DocumentParser.java @@ -696,6 +696,10 @@ private static void failIfMatchesRoutingPath(DocumentParserContext context, Stri */ private static void parseCopyFields(DocumentParserContext context, List copyToFields) throws IOException { for (String field : copyToFields) { + if (context.mappingLookup().getMapper(field) instanceof InferenceFieldMapper) { + // ignore copy_to that targets inference fields, values are already extracted in the coordinating node to perform inference. + continue; + } // In case of a hierarchy of nested documents, we need to figure out // which document the field should go to LuceneDocument targetDoc = null; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 494d6918b6086..666e7a3bd2043 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -23,7 +23,6 @@ import org.elasticsearch.core.TimeValue; import org.elasticsearch.features.NodeFeature; import org.elasticsearch.index.mapper.Mapper; -import org.elasticsearch.index.mapper.MetadataFieldMapper; import org.elasticsearch.indices.SystemIndexDescriptor; import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceRegistry; @@ -55,7 +54,6 @@ import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.RequestExecutorServiceSettings; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; -import org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper; import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; import org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.elasticsearch.xpack.inference.rest.RestDeleteInferenceModelAction; @@ -282,14 +280,6 @@ public Map getMappers() { return Map.of(); } - @Override - public Map getMetadataMappers() { - if (SemanticTextFeature.isEnabled()) { - return Map.of(InferenceMetadataFieldMapper.NAME, InferenceMetadataFieldMapper.PARSER); - } - return Map.of(); - } - @Override public Collection getActionFilters() { if (SemanticTextFeature.isEnabled()) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java index 2e6f66c64fa95..e79e91f2e2114 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -37,59 +37,28 @@ import org.elasticsearch.inference.Model; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; -import org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper; +import org.elasticsearch.xpack.inference.mapper.SemanticTextField; import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; import org.elasticsearch.xpack.inference.registry.ModelRegistry; import java.util.ArrayList; import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.stream.Collectors; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.toSemanticTextFieldChunks; + /** - * A {@link MappedActionFilter} intercepting {@link BulkShardRequest}s to apply inference on fields declared as - * {@link SemanticTextFieldMapper} in the index mapping. - * The source of each {@link BulkItemRequest} requiring inference is augmented with the results for each field - * under the {@link InferenceMetadataFieldMapper#NAME} section. - * For example, for an index with a semantic_text field named {@code my_semantic_field} the following source document: - *
- *
- * {
- *      "my_semantic_text_field": "these are not the droids you're looking for"
- * }
- * 
- * is rewritten into: - *
- *
- * {
- *      "_inference": {
- *        "my_semantic_field": {
- *          "inference_id": "my_inference_id",
- *                  "model_settings": {
- *                      "task_type": "SPARSE_EMBEDDING"
- *                  },
- *                  "chunks": [
- *                      {
- *                             "inference": {
- *                                 "lucas": 0.05212344,
- *                                 "ty": 0.041213956,
- *                                 "dragon": 0.50991,
- *                                 "type": 0.23241979,
- *                                 "dr": 1.9312073,
- *                                 "##o": 0.2797593
- *                             },
- *                             "text": "these are not the droids you're looking for"
- *                       }
- *                  ]
- *        }
- *      }
- *      "my_semantic_field": "these are not the droids you're looking for"
- * }
- * 
- * The rewriting process occurs on the bulk coordinator node, and the results are then passed downstream - * to the {@link TransportShardBulkAction} for actual indexing. + * A {@link MappedActionFilter} that intercepts {@link BulkShardRequest} to apply inference on fields specified + * as {@link SemanticTextFieldMapper} in the index mapping. For each semantic text field referencing fields in + * the request source, we generate embeddings and include the results in the source under the semantic text field + * name as a {@link SemanticTextField}. + * This transformation happens on the bulk coordinator node, and the {@link SemanticTextFieldMapper} parses the + * results during indexing on the shard. * * TODO: batchSize should be configurable via a cluster setting */ @@ -158,11 +127,52 @@ private void processBulkShardRequest( private record InferenceProvider(InferenceService service, Model model) {} - private record FieldInferenceRequest(int id, String field, String input) {} + /** + * A field inference request on a single input. + * @param id The id of the request in the original bulk request. + * @param field The target field. + * @param input The input to run inference on. + * @param inputOrder The original order of the input. + * @param isRawInput Whether the input is part of the raw values of the original field. + */ + private record FieldInferenceRequest(int id, String field, String input, int inputOrder, boolean isRawInput) {} - private record FieldInferenceResponse(String field, @Nullable Model model, @Nullable ChunkedInferenceServiceResults chunkedResults) {} + /** + * The field inference response. + * @param field The target field. + * @param input The input that was used to run inference. + * @param inputOrder The original order of the input. + * @param isRawInput Whether the input is part of the raw values of the original field. + * @param model The model used to run inference. + * @param chunkedResults The actual results. + */ + private record FieldInferenceResponse( + String field, + String input, + int inputOrder, + boolean isRawInput, + Model model, + ChunkedInferenceServiceResults chunkedResults + ) {} - private record FieldInferenceResponseAccumulator(int id, List responses, List failures) {} + private record FieldInferenceResponseAccumulator( + int id, + Map> responses, + List failures + ) { + void addOrUpdateResponse(FieldInferenceResponse response) { + synchronized (this) { + var list = responses.computeIfAbsent(response.field, k -> new ArrayList<>()); + list.add(response); + } + } + + void addFailure(Exception exc) { + synchronized (this) { + failures.add(exc); + } + } + } private class AsyncBulkShardInferenceAction implements Runnable { private final Map fieldInferenceMap; @@ -234,8 +244,8 @@ public void onResponse(ModelRegistry.UnparsedModel unparsedModel) { var request = requests.get(i); inferenceResults.get(request.id).failures.add( new ResourceNotFoundException( - "Inference id [{}] not found for field [{}]", - inferenceId, + "Inference service [{}] not found for field [{}]", + unparsedModel.service(), request.field ) ); @@ -271,7 +281,16 @@ public void onResponse(List results) { var request = requests.get(i); var result = results.get(i); var acc = inferenceResults.get(request.id); - acc.responses.add(new FieldInferenceResponse(request.field, inferenceProvider.model, result)); + acc.addOrUpdateResponse( + new FieldInferenceResponse( + request.field(), + request.input(), + request.inputOrder(), + request.isRawInput(), + inferenceProvider.model, + result + ) + ); } } finally { onFinish(); @@ -283,7 +302,8 @@ public void onFailure(Exception exc) { try { for (int i = 0; i < requests.size(); i++) { var request = requests.get(i); - inferenceResults.get(request.id).failures.add( + addInferenceResponseFailure( + request.id, new ElasticsearchException( "Exception when running inference id [{}] on field [{}]", exc, @@ -319,11 +339,7 @@ private void onFinish() { private FieldInferenceResponseAccumulator ensureResponseAccumulatorSlot(int id) { FieldInferenceResponseAccumulator acc = inferenceResults.get(id); if (acc == null) { - acc = new FieldInferenceResponseAccumulator( - id, - Collections.synchronizedList(new ArrayList<>()), - Collections.synchronizedList(new ArrayList<>()) - ); + acc = new FieldInferenceResponseAccumulator(id, new HashMap<>(), new ArrayList<>()); inferenceResults.set(id, acc); } return acc; @@ -331,14 +347,14 @@ private FieldInferenceResponseAccumulator ensureResponseAccumulatorSlot(int id) private void addInferenceResponseFailure(int id, Exception failure) { var acc = ensureResponseAccumulatorSlot(id); - acc.failures().add(failure); + acc.addFailure(failure); } /** * Applies the {@link FieldInferenceResponseAccumulator} to the provided {@link BulkItemRequest}. * If the response contains failures, the bulk item request is marked as failed for the downstream action. * Otherwise, the source of the request is augmented with the field inference results under the - * {@link InferenceMetadataFieldMapper#NAME} field. + * {@link SemanticTextFieldMapper#NAME} field. */ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceResponseAccumulator response) { if (response.failures().isEmpty() == false) { @@ -349,37 +365,41 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons } final IndexRequest indexRequest = getIndexRequestOrNull(item.request()); - Map newDocMap = indexRequest.sourceAsMap(); - Object inferenceObj = newDocMap.computeIfAbsent(InferenceMetadataFieldMapper.NAME, k -> new LinkedHashMap()); - Map inferenceMap = XContentMapValues.nodeMapValue(inferenceObj, InferenceMetadataFieldMapper.NAME); - newDocMap.put(InferenceMetadataFieldMapper.NAME, inferenceMap); - for (FieldInferenceResponse fieldResponse : response.responses()) { - if (fieldResponse.chunkedResults != null) { - try { - InferenceMetadataFieldMapper.applyFieldInference( - inferenceMap, - fieldResponse.field(), - fieldResponse.model(), - fieldResponse.chunkedResults() - ); - } catch (Exception exc) { - item.abort(item.index(), exc); - } - } else { - inferenceMap.remove(fieldResponse.field); - } + var newDocMap = indexRequest.sourceAsMap(); + for (var entry : response.responses.entrySet()) { + var fieldName = entry.getKey(); + var responses = entry.getValue(); + var model = responses.get(0).model(); + // ensure that the order in the raw field is consistent in case of multiple inputs + Collections.sort(responses, Comparator.comparingInt(FieldInferenceResponse::inputOrder)); + List inputs = responses.stream().filter(r -> r.isRawInput).map(r -> r.input).collect(Collectors.toList()); + List results = entry.getValue() + .stream() + .map(r -> r.chunkedResults) + .collect(Collectors.toList()); + var result = new SemanticTextField( + fieldName, + inputs, + new SemanticTextField.InferenceResult( + model.getInferenceEntityId(), + new SemanticTextField.ModelSettings(model), + toSemanticTextFieldChunks(fieldName, model.getInferenceEntityId(), results, indexRequest.getContentType()) + ), + indexRequest.getContentType() + ); + newDocMap.put(fieldName, result); } - indexRequest.source(newDocMap); + indexRequest.source(newDocMap, indexRequest.getContentType()); } /** * Register a {@link FieldInferenceRequest} for every non-empty field referencing an inference ID in the index. - * If results are already populated for fields in the existing _inference object, - * the inference request for this specific field is skipped, and the existing results remain unchanged. - * Validation of inference ID and model settings occurs in the {@link InferenceMetadataFieldMapper} - * during field indexing, where an error will be thrown if they mismatch or if the content is malformed. + * If results are already populated for fields in the original index request, the inference request for this specific + * field is skipped, and the existing results remain unchanged. + * Validation of inference ID and model settings occurs in the {@link SemanticTextFieldMapper} during field indexing, + * where an error will be thrown if they mismatch or if the content is malformed. * - * TODO: Should we validate the settings for pre-existing results here and apply the inference only if they differ? + * TODO: We should validate the settings for pre-existing results here and apply the inference only if they differ? */ private Map> createFieldInferenceRequests(BulkShardRequest bulkShardRequest) { Map> fieldRequestsMap = new LinkedHashMap<>(); @@ -411,17 +431,18 @@ private Map> createFieldInferenceRequests(Bu continue; } final Map docMap = indexRequest.sourceAsMap(); - final Map inferenceMap = XContentMapValues.nodeMapValue( - docMap.computeIfAbsent(InferenceMetadataFieldMapper.NAME, k -> new LinkedHashMap()), - InferenceMetadataFieldMapper.NAME - ); for (var entry : fieldInferenceMap.values()) { String field = entry.getName(); String inferenceId = entry.getInferenceId(); - Object inferenceResult = inferenceMap.remove(field); + var rawValue = XContentMapValues.extractValue(field, docMap); + if (rawValue instanceof Map) { + continue; + } + int order = 0; for (var sourceField : entry.getSourceFields()) { - var value = XContentMapValues.extractValue(sourceField, docMap); - if (value == null) { + boolean isRawField = sourceField.equals(field); + var valueObj = XContentMapValues.extractValue(sourceField, docMap); + if (valueObj == null) { if (isUpdateRequest) { addInferenceResponseFailure( item.id(), @@ -432,26 +453,25 @@ private Map> createFieldInferenceRequests(Bu field ) ); - } else if (inferenceResult != null) { - addInferenceResponseFailure( - item.id(), - new ElasticsearchStatusException( - "The field [{}] is referenced in the [{}] metadata field but has no value", - RestStatus.BAD_REQUEST, - field, - InferenceMetadataFieldMapper.NAME - ) - ); + break; } continue; } ensureResponseAccumulatorSlot(item.id()); - if (value instanceof String valueStr) { + if (valueObj instanceof String valueStr) { + List fieldRequests = fieldRequestsMap.computeIfAbsent( + inferenceId, + k -> new ArrayList<>() + ); + fieldRequests.add(new FieldInferenceRequest(item.id(), field, valueStr, order++, isRawField)); + } else if (valueObj instanceof List valueList) { List fieldRequests = fieldRequestsMap.computeIfAbsent( inferenceId, k -> new ArrayList<>() ); - fieldRequests.add(new FieldInferenceRequest(item.id(), field, valueStr)); + for (var value : valueList) { + fieldRequests.add(new FieldInferenceRequest(item.id(), field, value.toString(), order++, isRawField)); + } } else { addInferenceResponseFailure( item.id(), @@ -459,9 +479,10 @@ private Map> createFieldInferenceRequests(Bu "Invalid format for field [{}], expected [String] got [{}]", RestStatus.BAD_REQUEST, field, - value.getClass().getSimpleName() + valueObj.getClass().getSimpleName() ) ); + break; } } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapper.java deleted file mode 100644 index 89d1037243aac..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapper.java +++ /dev/null @@ -1,456 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.mapper; - -import org.apache.lucene.search.Query; -import org.elasticsearch.ElasticsearchException; -import org.elasticsearch.ElasticsearchStatusException; -import org.elasticsearch.common.Strings; -import org.elasticsearch.common.xcontent.support.XContentMapValues; -import org.elasticsearch.index.mapper.DocumentParserContext; -import org.elasticsearch.index.mapper.DocumentParsingException; -import org.elasticsearch.index.mapper.FieldMapper; -import org.elasticsearch.index.mapper.InferenceFieldMapper; -import org.elasticsearch.index.mapper.MappedFieldType; -import org.elasticsearch.index.mapper.Mapper; -import org.elasticsearch.index.mapper.MapperBuilderContext; -import org.elasticsearch.index.mapper.MetadataFieldMapper; -import org.elasticsearch.index.mapper.NestedObjectMapper; -import org.elasticsearch.index.mapper.ObjectMapper; -import org.elasticsearch.index.mapper.SourceLoader; -import org.elasticsearch.index.mapper.SourceValueFetcher; -import org.elasticsearch.index.mapper.TextSearchInfo; -import org.elasticsearch.index.mapper.ValueFetcher; -import org.elasticsearch.index.query.SearchExecutionContext; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; -import org.elasticsearch.inference.Model; -import org.elasticsearch.inference.TaskType; -import org.elasticsearch.logging.LogManager; -import org.elasticsearch.logging.Logger; -import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.xcontent.DeprecationHandler; -import org.elasticsearch.xcontent.NamedXContentRegistry; -import org.elasticsearch.xcontent.XContentLocation; -import org.elasticsearch.xcontent.XContentParser; -import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xcontent.support.MapXContentParser; -import org.elasticsearch.xpack.core.inference.results.ChunkedSparseEmbeddingResults; -import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingResults; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.HashSet; -import java.util.LinkedHashMap; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.stream.Collectors; - -import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper.canMergeModelSettings; - -/** - * A mapper for the {@code _inference} field. - *
- *
- * This mapper works in tandem with {@link SemanticTextFieldMapper semantic_text} fields to index inference results. - * The inference results for {@code semantic_text} fields are written to {@code _source} by an upstream process like so: - *
- *
- *
- * {
- *     "_source": {
- *         "my_semantic_text_field": "these are not the droids you're looking for",
- *         "_inference": {
- *             "my_semantic_text_field": {
- *                  "inference_id": "my_inference_id",
- *                  "model_settings": {
- *                      "task_type": "SPARSE_EMBEDDING"
- *                  },
- *                  "chunks" [
- *                      {
- *                          "inference": {
- *                              "lucas": 0.05212344,
- *                              "ty": 0.041213956,
- *                              "dragon": 0.50991,
- *                              "type": 0.23241979,
- *                              "dr": 1.9312073,
- *                              "##o": 0.2797593
- *                          },
- *                          "text": "these are not the droids you're looking for"
- *                      }
- *                  ]
- *              }
- *          }
- *      }
- * }
- * 
- * - * This mapper parses the contents of the {@code _inference} field and indexes it as if the mapping were configured like so: - *
- *
- *
- * {
- *     "mappings": {
- *         "properties": {
- *             "my_semantic_field": {
- *                 "chunks": {
- *                      "type": "nested",
- *                      "properties": {
- *                          "embedding": {
- *                              "type": "sparse_vector|dense_vector"
- *                          },
- *                          "text": {
- *                              "type": "keyword",
- *                              "index": false,
- *                              "doc_values": false
- *                          }
- *                     }
- *                 }
- *             }
- *         }
- *     }
- * }
- * 
- */ -public class InferenceMetadataFieldMapper extends MetadataFieldMapper { - public static final String NAME = InferenceFieldMapper.NAME; - public static final String CONTENT_TYPE = "_inference"; - - public static final String INFERENCE_ID = "inference_id"; - public static final String CHUNKS = "chunks"; - public static final String INFERENCE_CHUNKS_RESULTS = "inference"; - public static final String INFERENCE_CHUNKS_TEXT = "text"; - - public static final TypeParser PARSER = new FixedTypeParser(c -> new InferenceMetadataFieldMapper()); - - private static final Logger logger = LogManager.getLogger(InferenceMetadataFieldMapper.class); - - private static final Set REQUIRED_SUBFIELDS = Set.of(INFERENCE_CHUNKS_TEXT, INFERENCE_CHUNKS_RESULTS); - - static class SemanticTextInferenceFieldType extends MappedFieldType { - private static final MappedFieldType INSTANCE = new SemanticTextInferenceFieldType(); - - SemanticTextInferenceFieldType() { - super(NAME, true, false, false, TextSearchInfo.NONE, Collections.emptyMap()); - } - - @Override - public String typeName() { - return CONTENT_TYPE; - } - - @Override - public ValueFetcher valueFetcher(SearchExecutionContext context, String format) { - return SourceValueFetcher.identity(name(), context, format); - } - - @Override - public Query termQuery(Object value, SearchExecutionContext context) { - return null; - } - } - - public InferenceMetadataFieldMapper() { - super(SemanticTextInferenceFieldType.INSTANCE); - } - - @Override - protected void parseCreateField(DocumentParserContext context) throws IOException { - XContentParser parser = context.parser(); - failIfTokenIsNot(parser.getTokenLocation(), parser, XContentParser.Token.START_OBJECT); - boolean origWithLeafObject = context.path().isWithinLeafObject(); - try { - // make sure that we don't expand dots in field names while parsing - context.path().setWithinLeafObject(true); - for (XContentParser.Token token = parser.nextToken(); token != XContentParser.Token.END_OBJECT; token = parser.nextToken()) { - failIfTokenIsNot(parser.getTokenLocation(), parser, XContentParser.Token.FIELD_NAME); - parseSingleField(context); - } - } finally { - context.path().setWithinLeafObject(origWithLeafObject); - } - } - - private NestedObjectMapper updateSemanticTextFieldMapper( - DocumentParserContext docContext, - SemanticTextMapperContext semanticFieldContext, - String newInferenceId, - SemanticTextModelSettings newModelSettings, - XContentLocation xContentLocation - ) { - final String fullFieldName = semanticFieldContext.mapper.fieldType().name(); - final String inferenceId = semanticFieldContext.mapper.getInferenceId(); - if (newInferenceId.equals(inferenceId) == false) { - throw new DocumentParsingException( - xContentLocation, - Strings.format( - "The configured %s [%s] for field [%s] doesn't match the %s [%s] reported in the document.", - INFERENCE_ID, - inferenceId, - fullFieldName, - INFERENCE_ID, - newInferenceId - ) - ); - } - if (newModelSettings.taskType() == TaskType.TEXT_EMBEDDING && newModelSettings.dimensions() == null) { - throw new DocumentParsingException( - xContentLocation, - "Model settings for field [" + fullFieldName + "] must contain dimensions" - ); - } - if (semanticFieldContext.mapper.getModelSettings() == null) { - SemanticTextFieldMapper newMapper = new SemanticTextFieldMapper.Builder( - semanticFieldContext.mapper.simpleName(), - docContext.indexSettings().getIndexVersionCreated() - ).setInferenceId(newInferenceId).setModelSettings(newModelSettings).build(semanticFieldContext.context); - docContext.addDynamicMapper(newMapper); - return newMapper.getSubMappers(); - } else { - SemanticTextFieldMapper.Conflicts conflicts = new Conflicts(fullFieldName); - canMergeModelSettings(semanticFieldContext.mapper.getModelSettings(), newModelSettings, conflicts); - try { - conflicts.check(); - } catch (Exception exc) { - throw new DocumentParsingException(xContentLocation, "Incompatible model_settings", exc); - } - } - return semanticFieldContext.mapper.getSubMappers(); - } - - private void parseSingleField(DocumentParserContext context) throws IOException { - XContentParser parser = context.parser(); - String fieldName = parser.currentName(); - SemanticTextMapperContext builderContext = createSemanticFieldContext(context, fieldName); - if (builderContext == null) { - throw new DocumentParsingException( - parser.getTokenLocation(), - Strings.format("Field [%s] is not registered as a [%s] field type", fieldName, SemanticTextFieldMapper.CONTENT_TYPE) - ); - } - parser.nextToken(); - failIfTokenIsNot(parser.getTokenLocation(), parser, XContentParser.Token.START_OBJECT); - - // record the location of the inference field in the original source - XContentLocation xContentLocation = parser.getTokenLocation(); - // parse eagerly to extract the inference id and the model settings first - Map map = parser.mapOrdered(); - - // inference_id - Object inferenceIdObj = map.remove(INFERENCE_ID); - final String inferenceId = XContentMapValues.nodeStringValue(inferenceIdObj, null); - if (inferenceId == null) { - throw new IllegalArgumentException("required [" + INFERENCE_ID + "] is missing"); - } - - // model_settings - Object modelSettingsObj = map.remove(SemanticTextModelSettings.NAME); - if (modelSettingsObj == null) { - throw new DocumentParsingException( - parser.getTokenLocation(), - Strings.format( - "Missing required [%s] for field [%s] of type [%s]", - SemanticTextModelSettings.NAME, - fieldName, - SemanticTextFieldMapper.CONTENT_TYPE - ) - ); - } - final SemanticTextModelSettings modelSettings; - try { - modelSettings = SemanticTextModelSettings.fromMap(modelSettingsObj); - } catch (Exception exc) { - throw new DocumentParsingException( - xContentLocation, - Strings.format( - "Error parsing [%s] for field [%s] of type [%s]", - SemanticTextModelSettings.NAME, - fieldName, - SemanticTextFieldMapper.CONTENT_TYPE - ), - exc - ); - } - - var nestedObjectMapper = updateSemanticTextFieldMapper(context, builderContext, inferenceId, modelSettings, xContentLocation); - - // we know the model settings, so we can (re) parse the results array now - XContentParser subParser = new MapXContentParser( - NamedXContentRegistry.EMPTY, - DeprecationHandler.IGNORE_DEPRECATIONS, - map, - XContentType.JSON - ); - DocumentParserContext mapContext = context.switchParser(subParser); - parseFieldInference(xContentLocation, subParser, mapContext, nestedObjectMapper); - } - - private void parseFieldInference( - XContentLocation xContentLocation, - XContentParser parser, - DocumentParserContext context, - NestedObjectMapper nestedMapper - ) throws IOException { - parser.nextToken(); - failIfTokenIsNot(xContentLocation, parser, XContentParser.Token.START_OBJECT); - for (XContentParser.Token token = parser.nextToken(); token != XContentParser.Token.END_OBJECT; token = parser.nextToken()) { - switch (parser.currentName()) { - case CHUNKS -> parseChunks(xContentLocation, parser, context, nestedMapper); - default -> throw new DocumentParsingException(xContentLocation, "Unknown field name " + parser.currentName()); - } - } - } - - private void parseChunks( - XContentLocation xContentLocation, - XContentParser parser, - DocumentParserContext context, - NestedObjectMapper nestedMapper - ) throws IOException { - parser.nextToken(); - failIfTokenIsNot(xContentLocation, parser, XContentParser.Token.START_ARRAY); - for (XContentParser.Token token = parser.nextToken(); token != XContentParser.Token.END_ARRAY; token = parser.nextToken()) { - DocumentParserContext subContext = context.createNestedContext(nestedMapper); - parseResultsObject(xContentLocation, parser, subContext, nestedMapper); - } - } - - private void parseResultsObject( - XContentLocation xContentLocation, - XContentParser parser, - DocumentParserContext context, - NestedObjectMapper nestedMapper - ) throws IOException { - failIfTokenIsNot(xContentLocation, parser, XContentParser.Token.START_OBJECT); - Set visited = new HashSet<>(); - for (XContentParser.Token token = parser.nextToken(); token != XContentParser.Token.END_OBJECT; token = parser.nextToken()) { - failIfTokenIsNot(xContentLocation, parser, XContentParser.Token.FIELD_NAME); - visited.add(parser.currentName()); - FieldMapper fieldMapper = (FieldMapper) nestedMapper.getMapper(parser.currentName()); - if (fieldMapper == null) { - if (REQUIRED_SUBFIELDS.contains(parser.currentName())) { - throw new DocumentParsingException( - xContentLocation, - "Missing sub-fields definition for [" + parser.currentName() + "]" - ); - } else { - logger.debug("Skipping indexing of unrecognized field name [" + parser.currentName() + "]"); - advancePastCurrentFieldName(xContentLocation, parser); - continue; - } - } - parser.nextToken(); - fieldMapper.parse(context); - // Reset leaf object after parsing the field - context.path().setWithinLeafObject(true); - } - if (visited.containsAll(REQUIRED_SUBFIELDS) == false) { - Set missingSubfields = REQUIRED_SUBFIELDS.stream() - .filter(s -> visited.contains(s) == false) - .collect(Collectors.toSet()); - throw new DocumentParsingException(xContentLocation, "Missing required subfields: " + missingSubfields); - } - } - - private static void failIfTokenIsNot(XContentLocation xContentLocation, XContentParser parser, XContentParser.Token expected) { - if (parser.currentToken() != expected) { - throw new DocumentParsingException(xContentLocation, "Expected a " + expected.toString() + ", got " + parser.currentToken()); - } - } - - private static void advancePastCurrentFieldName(XContentLocation xContentLocation, XContentParser parser) throws IOException { - assert parser.currentToken() == XContentParser.Token.FIELD_NAME; - XContentParser.Token token = parser.nextToken(); - if (token == XContentParser.Token.START_OBJECT || token == XContentParser.Token.START_ARRAY) { - parser.skipChildren(); - } else if (token.isValue() == false && token != XContentParser.Token.VALUE_NULL) { - throw new DocumentParsingException(xContentLocation, "Expected a START_* or VALUE_*, got " + token); - } - } - - @Override - protected String contentType() { - return CONTENT_TYPE; - } - - @Override - public SourceLoader.SyntheticFieldLoader syntheticFieldLoader() { - return SourceLoader.SyntheticFieldLoader.NOTHING; - } - - @SuppressWarnings("unchecked") - public static void applyFieldInference( - Map inferenceMap, - String field, - Model model, - ChunkedInferenceServiceResults results - ) throws ElasticsearchException { - List> chunks = new ArrayList<>(); - if (results instanceof ChunkedSparseEmbeddingResults textExpansionResults) { - for (var chunk : textExpansionResults.getChunkedResults()) { - chunks.add(chunk.asMap()); - } - } else if (results instanceof ChunkedTextEmbeddingResults textEmbeddingResults) { - for (var chunk : textEmbeddingResults.getChunks()) { - chunks.add(chunk.asMap()); - } - } else { - throw new ElasticsearchStatusException( - "Invalid inference results format for field [{}] with inference id [{}], got {}", - RestStatus.BAD_REQUEST, - field, - model.getInferenceEntityId(), - results.getWriteableName() - ); - } - - Map fieldMap = (Map) inferenceMap.computeIfAbsent(field, s -> new LinkedHashMap<>()); - fieldMap.putAll(new SemanticTextModelSettings(model).asMap()); - List> fieldChunks = (List>) fieldMap.computeIfAbsent(CHUNKS, k -> new ArrayList<>()); - fieldChunks.addAll(chunks); - fieldMap.put(INFERENCE_ID, model.getInferenceEntityId()); - } - - record SemanticTextMapperContext(MapperBuilderContext context, SemanticTextFieldMapper mapper) {} - - /** - * Returns the {@link SemanticTextFieldMapper} associated with the provided {@code fullName} - * and the {@link MapperBuilderContext} that was used to build it. - * If the field is not found or is of the wrong type, this method returns {@code null}. - */ - static SemanticTextMapperContext createSemanticFieldContext(DocumentParserContext docContext, String fullName) { - ObjectMapper rootMapper = docContext.mappingLookup().getMapping().getRoot(); - return createSemanticFieldContext(MapperBuilderContext.root(false, false), rootMapper, fullName.split("\\.")); - } - - static SemanticTextMapperContext createSemanticFieldContext( - MapperBuilderContext mapperContext, - ObjectMapper objectMapper, - String[] paths - ) { - Mapper mapper = objectMapper.getMapper(paths[0]); - if (mapper instanceof ObjectMapper newObjectMapper) { - mapperContext = mapperContext.createChildContext(paths[0], ObjectMapper.Dynamic.FALSE); - return createSemanticFieldContext(mapperContext, newObjectMapper, Arrays.copyOfRange(paths, 1, paths.length)); - } else if (mapper instanceof SemanticTextFieldMapper semanticMapper) { - return new SemanticTextMapperContext(mapperContext, semanticMapper); - } else { - if (mapper == null || paths.length == 1) { - return null; - } - // check if the semantic field is defined within a multi-field - Mapper fieldMapper = objectMapper.getMapper(String.join(".", Arrays.asList(paths))); - if (fieldMapper instanceof SemanticTextFieldMapper semanticMapper) { - return new SemanticTextMapperContext(mapperContext, semanticMapper); - } - } - return null; - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java new file mode 100644 index 0000000000000..a69f98d4a230a --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java @@ -0,0 +1,328 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.mapper; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.common.xcontent.support.XContentMapValues; +import org.elasticsearch.core.Tuple; +import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.DeprecationHandler; +import org.elasticsearch.xcontent.NamedXContentRegistry; +import org.elasticsearch.xcontent.ObjectParser; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContent; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xcontent.support.MapXContentParser; +import org.elasticsearch.xpack.core.inference.results.ChunkedSparseEmbeddingResults; +import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingResults; +import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.inference.TaskType.SPARSE_EMBEDDING; +import static org.elasticsearch.inference.TaskType.TEXT_EMBEDDING; +import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; + +/** + * A {@link ToXContentObject} that is used to represent the transformation of the semantic text field's inputs. + * The resulting object preserves the original input under the {@link SemanticTextField#RAW_FIELD} and exposes + * the inference results under the {@link SemanticTextField#INFERENCE_FIELD}. + * + * @param fieldName The original field name. + * @param raw The raw values associated with the field name. + * @param inference The inference result. + * @param contentType The {@link XContentType} used to store the embeddings chunks. + */ +public record SemanticTextField(String fieldName, List raw, InferenceResult inference, XContentType contentType) + implements + ToXContentObject { + + static final ParseField RAW_FIELD = new ParseField("raw"); + static final ParseField INFERENCE_FIELD = new ParseField("inference"); + static final ParseField INFERENCE_ID_FIELD = new ParseField("inference_id"); + static final ParseField CHUNKS_FIELD = new ParseField("chunks"); + static final ParseField CHUNKED_EMBEDDINGS_FIELD = new ParseField("embeddings"); + static final ParseField CHUNKED_TEXT_FIELD = new ParseField("text"); + static final ParseField MODEL_SETTINGS_FIELD = new ParseField("model_settings"); + static final ParseField TASK_TYPE_FIELD = new ParseField("task_type"); + static final ParseField DIMENSIONS_FIELD = new ParseField("dimensions"); + static final ParseField SIMILARITY_FIELD = new ParseField("similarity"); + + public record InferenceResult(String inferenceId, ModelSettings modelSettings, List chunks) {} + + public record Chunk(String text, BytesReference rawEmbeddings) {} + + public record ModelSettings(TaskType taskType, Integer dimensions, SimilarityMeasure similarity) implements ToXContentObject { + public ModelSettings(Model model) { + this(model.getTaskType(), model.getServiceSettings().dimensions(), model.getServiceSettings().similarity()); + } + + public ModelSettings(TaskType taskType, Integer dimensions, SimilarityMeasure similarity) { + this.taskType = Objects.requireNonNull(taskType, "task type must not be null"); + this.dimensions = dimensions; + this.similarity = similarity; + validate(); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(TASK_TYPE_FIELD.getPreferredName(), taskType.toString()); + if (dimensions != null) { + builder.field(DIMENSIONS_FIELD.getPreferredName(), dimensions); + } + if (similarity != null) { + builder.field(SIMILARITY_FIELD.getPreferredName(), similarity); + } + return builder.endObject(); + } + + private void validate() { + switch (taskType) { + case TEXT_EMBEDDING: + if (dimensions == null) { + throw new IllegalArgumentException( + "required [" + DIMENSIONS_FIELD + "] field is missing for task_type [" + taskType.name() + "]" + ); + } + if (similarity == null) { + throw new IllegalArgumentException( + "required [" + SIMILARITY_FIELD + "] field is missing for task_type [" + taskType.name() + "]" + ); + } + break; + case SPARSE_EMBEDDING: + break; + + default: + throw new IllegalArgumentException( + "Wrong [" + + TASK_TYPE_FIELD.getPreferredName() + + "], expected " + + TEXT_EMBEDDING + + " or " + + SPARSE_EMBEDDING + + ", got " + + taskType.name() + ); + } + } + } + + public static String getRawFieldName(String fieldName) { + return fieldName + "." + RAW_FIELD.getPreferredName(); + } + + public static String getInferenceFieldName(String fieldName) { + return fieldName + "." + INFERENCE_FIELD.getPreferredName(); + } + + public static String getChunksFieldName(String fieldName) { + return getInferenceFieldName(fieldName) + "." + CHUNKS_FIELD.getPreferredName(); + } + + public static String getEmbeddingsFieldName(String fieldName) { + return getChunksFieldName(fieldName) + "." + CHUNKED_EMBEDDINGS_FIELD.getPreferredName(); + } + + static SemanticTextField parse(XContentParser parser, Tuple context) throws IOException { + return SEMANTIC_TEXT_FIELD_PARSER.parse(parser, context); + } + + static ModelSettings parseModelSettings(XContentParser parser) throws IOException { + return MODEL_SETTINGS_PARSER.parse(parser, null); + } + + static ModelSettings parseModelSettingsFromMap(Object node) { + if (node == null) { + return null; + } + try { + Map map = XContentMapValues.nodeMapValue(node, MODEL_SETTINGS_FIELD.getPreferredName()); + XContentParser parser = new MapXContentParser( + NamedXContentRegistry.EMPTY, + DeprecationHandler.IGNORE_DEPRECATIONS, + map, + XContentType.JSON + ); + return parseModelSettings(parser); + } catch (Exception exc) { + throw new ElasticsearchException(exc); + } + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (raw.isEmpty() == false) { + builder.field(RAW_FIELD.getPreferredName(), raw.size() == 1 ? raw.get(0) : raw); + } + builder.startObject(INFERENCE_FIELD.getPreferredName()); + builder.field(INFERENCE_ID_FIELD.getPreferredName(), inference.inferenceId); + builder.field(MODEL_SETTINGS_FIELD.getPreferredName(), inference.modelSettings); + builder.startArray(CHUNKS_FIELD.getPreferredName()); + for (var chunk : inference.chunks) { + builder.startObject(); + builder.field(CHUNKED_TEXT_FIELD.getPreferredName(), chunk.text); + XContentParser parser = XContentHelper.createParserNotCompressed( + XContentParserConfiguration.EMPTY, + chunk.rawEmbeddings, + contentType + ); + builder.field(CHUNKED_EMBEDDINGS_FIELD.getPreferredName()).copyCurrentStructure(parser); + builder.endObject(); + } + builder.endArray(); + builder.endObject(); + builder.endObject(); + return builder; + } + + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser> SEMANTIC_TEXT_FIELD_PARSER = + new ConstructingObjectParser<>( + "semantic", + true, + (args, context) -> new SemanticTextField( + context.v1(), + (List) (args[0] == null ? List.of() : args[0]), + (InferenceResult) args[1], + context.v2() + ) + ); + + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser INFERENCE_RESULT_PARSER = new ConstructingObjectParser<>( + "inference", + true, + args -> new InferenceResult((String) args[0], (ModelSettings) args[1], (List) args[2]) + ); + + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser CHUNKS_PARSER = new ConstructingObjectParser<>( + "chunks", + true, + args -> new Chunk((String) args[0], (BytesReference) args[1]) + ); + + private static final ConstructingObjectParser MODEL_SETTINGS_PARSER = new ConstructingObjectParser<>( + "model_settings", + true, + args -> { + TaskType taskType = TaskType.fromString((String) args[0]); + Integer dimensions = (Integer) args[1]; + SimilarityMeasure similarity = args[2] == null ? null : SimilarityMeasure.fromString((String) args[2]); + return new ModelSettings(taskType, dimensions, similarity); + } + ); + + static { + SEMANTIC_TEXT_FIELD_PARSER.declareStringArray(optionalConstructorArg(), RAW_FIELD); + SEMANTIC_TEXT_FIELD_PARSER.declareObject(constructorArg(), (p, c) -> INFERENCE_RESULT_PARSER.parse(p, null), INFERENCE_FIELD); + + INFERENCE_RESULT_PARSER.declareString(constructorArg(), INFERENCE_ID_FIELD); + INFERENCE_RESULT_PARSER.declareObject(constructorArg(), (p, c) -> MODEL_SETTINGS_PARSER.parse(p, c), MODEL_SETTINGS_FIELD); + INFERENCE_RESULT_PARSER.declareObjectArray(constructorArg(), (p, c) -> CHUNKS_PARSER.parse(p, c), CHUNKS_FIELD); + + CHUNKS_PARSER.declareString(constructorArg(), CHUNKED_TEXT_FIELD); + CHUNKS_PARSER.declareField(constructorArg(), (p, c) -> { + XContentBuilder b = XContentBuilder.builder(p.contentType().xContent()); + b.copyCurrentStructure(p); + return BytesReference.bytes(b); + }, CHUNKED_EMBEDDINGS_FIELD, ObjectParser.ValueType.OBJECT_ARRAY); + + MODEL_SETTINGS_PARSER.declareString(ConstructingObjectParser.constructorArg(), TASK_TYPE_FIELD); + MODEL_SETTINGS_PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), DIMENSIONS_FIELD); + MODEL_SETTINGS_PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), SIMILARITY_FIELD); + } + + /** + * Converts the provided {@link ChunkedInferenceServiceResults} into a list of {@link Chunk}. + */ + public static List toSemanticTextFieldChunks( + String field, + String inferenceId, + List results, + XContentType contentType + ) { + List chunks = new ArrayList<>(); + for (var result : results) { + if (result instanceof ChunkedSparseEmbeddingResults textExpansionResults) { + for (var chunk : textExpansionResults.getChunkedResults()) { + chunks.add(new Chunk(chunk.matchedText(), toBytesReference(contentType.xContent(), chunk.weightedTokens()))); + } + } else if (result instanceof ChunkedTextEmbeddingResults textEmbeddingResults) { + for (var chunk : textEmbeddingResults.getChunks()) { + chunks.add(new Chunk(chunk.matchedText(), toBytesReference(contentType.xContent(), chunk.embedding()))); + } + } else { + throw new ElasticsearchStatusException( + "Invalid inference results format for field [{}] with inference id [{}], got {}", + RestStatus.BAD_REQUEST, + field, + inferenceId, + result.getWriteableName() + ); + } + } + return chunks; + } + + /** + * Serialises the {@code value} array, according to the provided {@link XContent}, into a {@link BytesReference}. + */ + private static BytesReference toBytesReference(XContent xContent, double[] value) { + try { + XContentBuilder b = XContentBuilder.builder(xContent); + b.startArray(); + for (double v : value) { + b.value(v); + } + b.endArray(); + return BytesReference.bytes(b); + } catch (IOException exc) { + throw new RuntimeException(exc); + } + } + + /** + * Serialises the {@link TextExpansionResults.WeightedToken} list, according to the provided {@link XContent}, + * into a {@link BytesReference}. + */ + private static BytesReference toBytesReference(XContent xContent, List tokens) { + try { + XContentBuilder b = XContentBuilder.builder(xContent); + b.startObject(); + for (var weightedToken : tokens) { + weightedToken.toXContent(b, ToXContent.EMPTY_PARAMS); + } + b.endObject(); + return BytesReference.bytes(b); + } catch (IOException exc) { + throw new RuntimeException(exc); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java index f8fde0b63e4ea..c80c84d414dba 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java @@ -9,11 +9,16 @@ import org.apache.lucene.search.Query; import org.elasticsearch.cluster.metadata.InferenceFieldMetadata; +import org.elasticsearch.common.Explicit; import org.elasticsearch.common.Strings; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.Tuple; import org.elasticsearch.index.IndexVersion; import org.elasticsearch.index.fielddata.FieldDataContext; import org.elasticsearch.index.fielddata.IndexFieldData; import org.elasticsearch.index.mapper.DocumentParserContext; +import org.elasticsearch.index.mapper.DocumentParsingException; import org.elasticsearch.index.mapper.FieldMapper; import org.elasticsearch.index.mapper.InferenceFieldMapper; import org.elasticsearch.index.mapper.KeywordFieldMapper; @@ -35,9 +40,13 @@ import org.elasticsearch.logging.LogManager; import org.elasticsearch.logging.Logger; import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentLocation; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; import java.io.IOException; import java.util.ArrayList; +import java.util.Arrays; import java.util.Iterator; import java.util.List; import java.util.Map; @@ -45,103 +54,33 @@ import java.util.Set; import java.util.function.Function; -import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper.CHUNKS; -import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper.INFERENCE_CHUNKS_RESULTS; -import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper.INFERENCE_CHUNKS_TEXT; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.CHUNKED_EMBEDDINGS_FIELD; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.CHUNKED_TEXT_FIELD; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.CHUNKS_FIELD; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.INFERENCE_FIELD; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.INFERENCE_ID_FIELD; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.getRawFieldName; /** * A {@link FieldMapper} for semantic text fields. - * These fields have a reference id reference, that is used for performing inference at ingestion and query time. - * This field mapper performs no indexing, as inference results will be included as a different field in the document source, and will - * be indexed using {@link InferenceMetadataFieldMapper}. */ public class SemanticTextFieldMapper extends FieldMapper implements InferenceFieldMapper { - private static final Logger logger = LogManager.getLogger(SemanticTextFieldMapper.class); - public static final String CONTENT_TYPE = "semantic_text"; - private static SemanticTextFieldMapper toType(FieldMapper in) { - return (SemanticTextFieldMapper) in; - } + private static final Logger logger = LogManager.getLogger(SemanticTextFieldMapper.class); public static final TypeParser PARSER = new TypeParser( (n, c) -> new Builder(n, c.indexVersionCreated()), notInMultiFields(CONTENT_TYPE) ); - private final IndexVersion indexVersionCreated; - private final String inferenceId; - private final SemanticTextModelSettings modelSettings; - private final NestedObjectMapper subMappers; - - private SemanticTextFieldMapper( - String simpleName, - MappedFieldType mappedFieldType, - CopyTo copyTo, - IndexVersion indexVersionCreated, - String inferenceId, - SemanticTextModelSettings modelSettings, - NestedObjectMapper subMappers - ) { - super(simpleName, mappedFieldType, MultiFields.empty(), copyTo); - this.indexVersionCreated = indexVersionCreated; - this.inferenceId = inferenceId; - this.modelSettings = modelSettings; - this.subMappers = subMappers; - } - - @Override - public Iterator iterator() { - List subIterators = new ArrayList<>(); - subIterators.add(subMappers); - return subIterators.iterator(); - } - - @Override - public FieldMapper.Builder getMergeBuilder() { - return new Builder(simpleName(), indexVersionCreated).init(this); - } - - @Override - protected void parseCreateField(DocumentParserContext context) throws IOException { - // Just parses text - no indexing is performed - context.parser().textOrNull(); - } - - @Override - protected String contentType() { - return CONTENT_TYPE; - } - - @Override - public SemanticTextFieldType fieldType() { - return (SemanticTextFieldType) super.fieldType(); - } - - public String getInferenceId() { - return inferenceId; - } - - public SemanticTextModelSettings getModelSettings() { - return modelSettings; - } - - public NestedObjectMapper getSubMappers() { - return subMappers; - } - - @Override - public InferenceFieldMetadata getMetadata(Set sourcePaths) { - return new InferenceFieldMetadata(name(), inferenceId, sourcePaths.toArray(String[]::new)); - } - public static class Builder extends FieldMapper.Builder { private final IndexVersion indexVersionCreated; private final Parameter inferenceId = Parameter.stringParam( "inference_id", false, - m -> toType(m).fieldType().inferenceId, + mapper -> ((SemanticTextFieldType) mapper.fieldType()).inferenceId, null ).addValidator(v -> { if (Strings.isEmpty(v)) { @@ -149,24 +88,24 @@ public static class Builder extends FieldMapper.Builder { } }); - private final Parameter modelSettings = new Parameter<>( + private final Parameter modelSettings = new Parameter<>( "model_settings", true, () -> null, - (n, c, o) -> SemanticTextModelSettings.fromMap(o), - mapper -> ((SemanticTextFieldMapper) mapper).modelSettings, + (n, c, o) -> SemanticTextField.parseModelSettingsFromMap(o), + mapper -> ((SemanticTextFieldType) mapper.fieldType()).modelSettings, XContentBuilder::field, (m) -> m == null ? "null" : Strings.toString(m) ).acceptsNull().setMergeValidator(SemanticTextFieldMapper::canMergeModelSettings); private final Parameter> meta = Parameter.metaParam(); - private Function subFieldsFunction; + private Function inferenceFieldBuilder; public Builder(String name, IndexVersion indexVersionCreated) { super(name); this.indexVersionCreated = indexVersionCreated; - this.subFieldsFunction = c -> createSubFields(c); + this.inferenceFieldBuilder = c -> createInferenceField(c, indexVersionCreated, modelSettings.get()); } public Builder setInferenceId(String id) { @@ -174,7 +113,7 @@ public Builder setInferenceId(String id) { return this; } - public Builder setModelSettings(SemanticTextModelSettings value) { + public Builder setModelSettings(SemanticTextField.ModelSettings value) { this.modelSettings.setValue(value); return this; } @@ -188,63 +127,156 @@ protected Parameter[] getParameters() { protected void merge(FieldMapper mergeWith, Conflicts conflicts, MapperMergeContext mapperMergeContext) { super.merge(mergeWith, conflicts, mapperMergeContext); conflicts.check(); - SemanticTextFieldMapper semanticMergeWith = (SemanticTextFieldMapper) mergeWith; - var childMergeContext = mapperMergeContext.createChildContext(name(), ObjectMapper.Dynamic.FALSE); - NestedObjectMapper mergedSubFields = (NestedObjectMapper) semanticMergeWith.getSubMappers() - .merge( - subFieldsFunction.apply(childMergeContext.getMapperBuilderContext()), - MapperService.MergeReason.MAPPING_UPDATE, - childMergeContext - ); - subFieldsFunction = c -> mergedSubFields; + var semanticMergeWith = (SemanticTextFieldMapper) mergeWith; + var context = mapperMergeContext.createChildContext(mergeWith.simpleName(), ObjectMapper.Dynamic.FALSE); + var inferenceField = inferenceFieldBuilder.apply(context.getMapperBuilderContext()); + var childContext = context.createChildContext(inferenceField.simpleName(), ObjectMapper.Dynamic.FALSE); + var mergedInferenceField = inferenceField.merge( + semanticMergeWith.fieldType().getInferenceField(), + MapperService.MergeReason.MAPPING_UPDATE, + childContext + ); + inferenceFieldBuilder = c -> mergedInferenceField; } @Override public SemanticTextFieldMapper build(MapperBuilderContext context) { final String fullName = context.buildFullName(name()); var childContext = context.createChildContext(name(), ObjectMapper.Dynamic.FALSE); - final NestedObjectMapper subFields = subFieldsFunction.apply(childContext); + final ObjectMapper inferenceField = inferenceFieldBuilder.apply(childContext); return new SemanticTextFieldMapper( name(), - new SemanticTextFieldType(fullName, inferenceId.getValue(), modelSettings.getValue(), subFields, meta.getValue()), - copyTo, - indexVersionCreated, - inferenceId.getValue(), - modelSettings.getValue(), - subFields + new SemanticTextFieldType( + fullName, + inferenceId.getValue(), + modelSettings.getValue(), + inferenceField, + indexVersionCreated, + meta.getValue() + ), + copyTo ); } + } + + private SemanticTextFieldMapper(String simpleName, MappedFieldType mappedFieldType, CopyTo copyTo) { + super(simpleName, mappedFieldType, MultiFields.empty(), copyTo); + } + + @Override + public Iterator iterator() { + List subIterators = new ArrayList<>(); + subIterators.add(fieldType().getInferenceField()); + return subIterators.iterator(); + } - private NestedObjectMapper createSubFields(MapperBuilderContext context) { - NestedObjectMapper.Builder nestedBuilder = new NestedObjectMapper.Builder(CHUNKS, indexVersionCreated); - nestedBuilder.dynamic(ObjectMapper.Dynamic.FALSE); - KeywordFieldMapper.Builder textMapperBuilder = new KeywordFieldMapper.Builder(INFERENCE_CHUNKS_TEXT, indexVersionCreated) - .indexed(false) - .docValues(false); - if (modelSettings.get() != null) { - nestedBuilder.add(createInferenceMapperBuilder(INFERENCE_CHUNKS_RESULTS, modelSettings.get(), indexVersionCreated)); + @Override + public FieldMapper.Builder getMergeBuilder() { + return new Builder(simpleName(), fieldType().indexVersionCreated).init(this); + } + + @Override + protected void parseCreateField(DocumentParserContext context) throws IOException { + XContentParser parser = context.parser(); + if (parser.currentToken() == XContentParser.Token.VALUE_NULL) { + return; + } + XContentLocation xContentLocation = parser.getTokenLocation(); + final SemanticTextField field; + boolean isWithinLeaf = context.path().isWithinLeafObject(); + try { + context.path().setWithinLeafObject(true); + field = SemanticTextField.parse(parser, new Tuple<>(name(), context.parser().contentType())); + } finally { + context.path().setWithinLeafObject(isWithinLeaf); + } + final String fullFieldName = fieldType().name(); + if (field.inference().inferenceId().equals(fieldType().getInferenceId()) == false) { + throw new DocumentParsingException( + xContentLocation, + Strings.format( + "The configured %s [%s] for field [%s] doesn't match the %s [%s] reported in the document.", + INFERENCE_ID_FIELD.getPreferredName(), + field.inference().inferenceId(), + fullFieldName, + INFERENCE_ID_FIELD.getPreferredName(), + fieldType().getInferenceId() + ) + ); + } + final SemanticTextFieldMapper mapper; + if (fieldType().getModelSettings() == null) { + context.path().remove(); + Builder builder = (Builder) new Builder(simpleName(), fieldType().indexVersionCreated).init(this); + try { + mapper = builder.setModelSettings(field.inference().modelSettings()) + .setInferenceId(field.inference().inferenceId()) + .build(context.createDynamicMapperBuilderContext()); + context.addDynamicMapper(mapper); + } finally { + context.path().add(simpleName()); + } + } else { + SemanticTextFieldMapper.Conflicts conflicts = new Conflicts(fullFieldName); + canMergeModelSettings(field.inference().modelSettings(), fieldType().getModelSettings(), conflicts); + try { + conflicts.check(); + } catch (Exception exc) { + throw new DocumentParsingException(xContentLocation, "Incompatible model_settings", exc); } - nestedBuilder.add(textMapperBuilder); - return nestedBuilder.build(context); + mapper = this; + } + var chunksField = mapper.fieldType().getChunksField(); + var embeddingsField = mapper.fieldType().getEmbeddingsField(); + for (var chunk : field.inference().chunks()) { + XContentParser subParser = XContentHelper.createParserNotCompressed( + XContentParserConfiguration.EMPTY, + chunk.rawEmbeddings(), + context.parser().contentType() + ); + DocumentParserContext subContext = context.createNestedContext(chunksField).switchParser(subParser); + subParser.nextToken(); + embeddingsField.parse(subContext); } } + @Override + protected String contentType() { + return CONTENT_TYPE; + } + + @Override + public SemanticTextFieldType fieldType() { + return (SemanticTextFieldType) super.fieldType(); + } + + @Override + public InferenceFieldMetadata getMetadata(Set sourcePaths) { + String[] copyFields = sourcePaths.toArray(String[]::new); + // ensure consistent order + Arrays.sort(copyFields); + return new InferenceFieldMetadata(name(), fieldType().inferenceId, copyFields); + } + public static class SemanticTextFieldType extends SimpleMappedFieldType { private final String inferenceId; - private final SemanticTextModelSettings modelSettings; - private final NestedObjectMapper subMappers; + private final SemanticTextField.ModelSettings modelSettings; + private final ObjectMapper inferenceField; + private final IndexVersion indexVersionCreated; public SemanticTextFieldType( String name, String modelId, - SemanticTextModelSettings modelSettings, - NestedObjectMapper subMappers, + SemanticTextField.ModelSettings modelSettings, + ObjectMapper inferenceField, + IndexVersion indexVersionCreated, Map meta ) { super(name, false, false, false, TextSearchInfo.NONE, meta); this.inferenceId = modelId; this.modelSettings = modelSettings; - this.subMappers = subMappers; + this.inferenceField = inferenceField; + this.indexVersionCreated = indexVersionCreated; } @Override @@ -256,22 +288,31 @@ public String getInferenceId() { return inferenceId; } - public SemanticTextModelSettings getModelSettings() { + public SemanticTextField.ModelSettings getModelSettings() { return modelSettings; } - public NestedObjectMapper getSubMappers() { - return subMappers; + public ObjectMapper getInferenceField() { + return inferenceField; + } + + public NestedObjectMapper getChunksField() { + return (NestedObjectMapper) inferenceField.getMapper(CHUNKS_FIELD.getPreferredName()); + } + + public FieldMapper getEmbeddingsField() { + return (FieldMapper) getChunksField().getMapper(CHUNKED_EMBEDDINGS_FIELD.getPreferredName()); } @Override public Query termQuery(Object value, SearchExecutionContext context) { - throw new IllegalArgumentException("termQuery not implemented yet"); + throw new IllegalArgumentException(CONTENT_TYPE + " fields do not support term query"); } @Override public ValueFetcher valueFetcher(SearchExecutionContext context, String format) { - return SourceValueFetcher.toString(name(), context, format); + // Redirect the fetcher to load the value from the raw field + return SourceValueFetcher.toString(getRawFieldName(name()), context, format); } @Override @@ -280,16 +321,39 @@ public IndexFieldData.Builder fielddataBuilder(FieldDataContext fieldDataContext } } - private static Mapper.Builder createInferenceMapperBuilder( - String fieldName, - SemanticTextModelSettings modelSettings, - IndexVersion indexVersionCreated + private static ObjectMapper createInferenceField( + MapperBuilderContext context, + IndexVersion indexVersionCreated, + @Nullable SemanticTextField.ModelSettings modelSettings + ) { + return new ObjectMapper.Builder(INFERENCE_FIELD.getPreferredName(), Explicit.EXPLICIT_TRUE).dynamic(ObjectMapper.Dynamic.FALSE) + .add(createChunksField(indexVersionCreated, modelSettings)) + .build(context); + } + + private static NestedObjectMapper.Builder createChunksField( + IndexVersion indexVersionCreated, + SemanticTextField.ModelSettings modelSettings ) { + NestedObjectMapper.Builder chunksField = new NestedObjectMapper.Builder(CHUNKS_FIELD.getPreferredName(), indexVersionCreated); + chunksField.dynamic(ObjectMapper.Dynamic.FALSE); + KeywordFieldMapper.Builder chunkTextField = new KeywordFieldMapper.Builder( + CHUNKED_TEXT_FIELD.getPreferredName(), + indexVersionCreated + ).indexed(false).docValues(false); + if (modelSettings != null) { + chunksField.add(createEmbeddingsField(indexVersionCreated, modelSettings)); + } + chunksField.add(chunkTextField); + return chunksField; + } + + private static Mapper.Builder createEmbeddingsField(IndexVersion indexVersionCreated, SemanticTextField.ModelSettings modelSettings) { return switch (modelSettings.taskType()) { - case SPARSE_EMBEDDING -> new SparseVectorFieldMapper.Builder(INFERENCE_CHUNKS_RESULTS); + case SPARSE_EMBEDDING -> new SparseVectorFieldMapper.Builder(CHUNKED_EMBEDDINGS_FIELD.getPreferredName()); case TEXT_EMBEDDING -> { DenseVectorFieldMapper.Builder denseVectorMapperBuilder = new DenseVectorFieldMapper.Builder( - INFERENCE_CHUNKS_RESULTS, + CHUNKED_EMBEDDINGS_FIELD.getPreferredName(), indexVersionCreated ); SimilarityMeasure similarity = modelSettings.similarity(); @@ -298,22 +362,20 @@ private static Mapper.Builder createInferenceMapperBuilder( case COSINE -> denseVectorMapperBuilder.similarity(DenseVectorFieldMapper.VectorSimilarity.COSINE); case DOT_PRODUCT -> denseVectorMapperBuilder.similarity(DenseVectorFieldMapper.VectorSimilarity.DOT_PRODUCT); default -> throw new IllegalArgumentException( - "Unknown similarity measure for field [" + fieldName + "] in model settings: " + similarity + "Unknown similarity measure in model_settings [" + similarity.name() + "]" ); } } denseVectorMapperBuilder.dimensions(modelSettings.dimensions()); yield denseVectorMapperBuilder; } - default -> throw new IllegalArgumentException( - "Invalid [task_type] for [" + fieldName + "] in model settings: " + modelSettings.taskType().name() - ); + default -> throw new IllegalArgumentException("Invalid task_type in model_settings [" + modelSettings.taskType().name() + "]"); }; } - static boolean canMergeModelSettings( - SemanticTextModelSettings previous, - SemanticTextModelSettings current, + private static boolean canMergeModelSettings( + SemanticTextField.ModelSettings previous, + SemanticTextField.ModelSettings current, FieldMapper.Conflicts conflicts ) { if (Objects.equals(previous, current)) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextModelSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextModelSettings.java deleted file mode 100644 index b1d0511008db8..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextModelSettings.java +++ /dev/null @@ -1,181 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.mapper; - -import org.elasticsearch.ElasticsearchException; -import org.elasticsearch.common.xcontent.support.XContentMapValues; -import org.elasticsearch.inference.Model; -import org.elasticsearch.inference.SimilarityMeasure; -import org.elasticsearch.inference.TaskType; -import org.elasticsearch.xcontent.ConstructingObjectParser; -import org.elasticsearch.xcontent.DeprecationHandler; -import org.elasticsearch.xcontent.NamedXContentRegistry; -import org.elasticsearch.xcontent.ParseField; -import org.elasticsearch.xcontent.ToXContentObject; -import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xcontent.XContentParser; -import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xcontent.support.MapXContentParser; - -import java.io.IOException; -import java.util.HashMap; -import java.util.Map; -import java.util.Objects; - -import static org.elasticsearch.inference.TaskType.SPARSE_EMBEDDING; -import static org.elasticsearch.inference.TaskType.TEXT_EMBEDDING; - -/** - * Serialization class for specifying the settings of a model from semantic_text inference to field mapper. - */ -public class SemanticTextModelSettings implements ToXContentObject { - - public static final String NAME = "model_settings"; - public static final ParseField TASK_TYPE_FIELD = new ParseField("task_type"); - public static final ParseField DIMENSIONS_FIELD = new ParseField("dimensions"); - public static final ParseField SIMILARITY_FIELD = new ParseField("similarity"); - private final TaskType taskType; - private final Integer dimensions; - private final SimilarityMeasure similarity; - - public SemanticTextModelSettings(Model model) { - this(model.getTaskType(), model.getServiceSettings().dimensions(), model.getServiceSettings().similarity()); - } - - public SemanticTextModelSettings(TaskType taskType, Integer dimensions, SimilarityMeasure similarity) { - Objects.requireNonNull(taskType, "task type must not be null"); - this.taskType = taskType; - this.dimensions = dimensions; - this.similarity = similarity; - validate(); - } - - public static SemanticTextModelSettings parse(XContentParser parser) throws IOException { - return PARSER.apply(parser, null); - } - - private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( - NAME, - true, - args -> { - TaskType taskType = TaskType.fromString((String) args[0]); - Integer dimensions = (Integer) args[1]; - SimilarityMeasure similarity = args[2] == null ? null : SimilarityMeasure.fromString((String) args[2]); - return new SemanticTextModelSettings(taskType, dimensions, similarity); - } - ); - static { - PARSER.declareString(ConstructingObjectParser.constructorArg(), TASK_TYPE_FIELD); - PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), DIMENSIONS_FIELD); - PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), SIMILARITY_FIELD); - } - - public static SemanticTextModelSettings fromMap(Object node) { - if (node == null) { - return null; - } - try { - Map map = XContentMapValues.nodeMapValue(node, NAME); - if (map.containsKey(TASK_TYPE_FIELD.getPreferredName()) == false) { - throw new IllegalArgumentException( - "Failed to parse [" + NAME + "], required [" + TASK_TYPE_FIELD.getPreferredName() + "] is missing" - ); - } - XContentParser parser = new MapXContentParser( - NamedXContentRegistry.EMPTY, - DeprecationHandler.IGNORE_DEPRECATIONS, - map, - XContentType.JSON - ); - return SemanticTextModelSettings.parse(parser); - } catch (Exception exc) { - throw new ElasticsearchException(exc); - } - } - - public Map asMap() { - Map attrsMap = new HashMap<>(); - attrsMap.put(TASK_TYPE_FIELD.getPreferredName(), taskType.toString()); - if (dimensions != null) { - attrsMap.put(DIMENSIONS_FIELD.getPreferredName(), dimensions); - } - if (similarity != null) { - attrsMap.put(SIMILARITY_FIELD.getPreferredName(), similarity); - } - return Map.of(NAME, attrsMap); - } - - public TaskType taskType() { - return taskType; - } - - public Integer dimensions() { - return dimensions; - } - - public SimilarityMeasure similarity() { - return similarity; - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - builder.field(TASK_TYPE_FIELD.getPreferredName(), taskType.toString()); - if (dimensions != null) { - builder.field(DIMENSIONS_FIELD.getPreferredName(), dimensions); - } - if (similarity != null) { - builder.field(SIMILARITY_FIELD.getPreferredName(), similarity); - } - return builder.endObject(); - } - - public void validate() { - switch (taskType) { - case TEXT_EMBEDDING: - if (dimensions == null) { - throw new IllegalArgumentException( - "required [" + DIMENSIONS_FIELD + "] field is missing for task_type [" + taskType.name() + "]" - ); - } - if (similarity == null) { - throw new IllegalArgumentException( - "required [" + SIMILARITY_FIELD + "] field is missing for task_type [" + taskType.name() + "]" - ); - } - break; - case SPARSE_EMBEDDING: - break; - - default: - throw new IllegalArgumentException( - "Wrong [" - + TASK_TYPE_FIELD.getPreferredName() - + "], expected " - + TEXT_EMBEDDING - + " or " - + SPARSE_EMBEDDING - + ", got " - + taskType.name() - ); - } - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - SemanticTextModelSettings that = (SemanticTextModelSettings) o; - return taskType == that.taskType && Objects.equals(dimensions, that.dimensions) && similarity == that.similarity; - } - - @Override - public int hashCode() { - return Objects.hash(taskType, dimensions, similarity); - } -} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java index d734e9998734d..5cb2acfadc2f9 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java @@ -30,9 +30,10 @@ import org.elasticsearch.test.ESTestCase; import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xcontent.json.JsonXContent; import org.elasticsearch.xpack.core.inference.results.ChunkedSparseEmbeddingResults; -import org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper; +import org.elasticsearch.xpack.inference.mapper.SemanticTextField; import org.elasticsearch.xpack.inference.model.TestModel; import org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.junit.After; @@ -51,8 +52,8 @@ import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.awaitLatch; import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.DEFAULT_BATCH_SIZE; -import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapperTests.randomSparseEmbeddings; -import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapperTests.randomTextEmbeddings; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomSemanticText; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.toChunkedResult; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.instanceOf; import static org.mockito.Mockito.any; @@ -267,43 +268,25 @@ private static BulkItemRequest[] randomBulkItemRequest( Map fieldInferenceMap ) { Map docMap = new LinkedHashMap<>(); - Map inferenceResultsMap = new LinkedHashMap<>(); + Map expectedDocMap = new LinkedHashMap<>(); for (var entry : fieldInferenceMap.values()) { String field = entry.getName(); var model = modelMap.get(entry.getInferenceId()); String text = randomAlphaOfLengthBetween(10, 100); docMap.put(field, text); + expectedDocMap.put(field, text); if (model == null) { // ignore results, the doc should fail with a resource not found exception continue; } - int numChunks = randomIntBetween(1, 5); - List chunks = new ArrayList<>(); - for (int i = 0; i < numChunks; i++) { - chunks.add(randomAlphaOfLengthBetween(5, 10)); - } - TaskType taskType = model.getTaskType(); - final ChunkedInferenceServiceResults results; - switch (taskType) { - case TEXT_EMBEDDING: - results = randomTextEmbeddings(model, chunks); - break; - - case SPARSE_EMBEDDING: - results = randomSparseEmbeddings(chunks); - break; - - default: - throw new AssertionError("Unknown task type " + taskType.name()); - } - model.putResult(text, results); - InferenceMetadataFieldMapper.applyFieldInference(inferenceResultsMap, field, model, results); + var result = randomSemanticText(field, model, List.of(text), randomFrom(XContentType.values())); + model.putResult(text, result); + expectedDocMap.put(field, result); } - Map expectedDocMap = new LinkedHashMap<>(docMap); - expectedDocMap.put(InferenceMetadataFieldMapper.NAME, inferenceResultsMap); + XContentType requestContentType = randomFrom(XContentType.values()); return new BulkItemRequest[] { - new BulkItemRequest(id, new IndexRequest("index").source(docMap)), - new BulkItemRequest(id, new IndexRequest("index").source(expectedDocMap)) }; + new BulkItemRequest(id, new IndexRequest("index").source(docMap, requestContentType)), + new BulkItemRequest(id, new IndexRequest("index").source(expectedDocMap, requestContentType)) }; } private static StaticModel randomStaticModel() { @@ -320,7 +303,7 @@ private static StaticModel randomStaticModel() { } private static class StaticModel extends TestModel { - private final Map resultMap; + private final Map resultMap; StaticModel( String inferenceEntityId, @@ -335,11 +318,15 @@ private static class StaticModel extends TestModel { } ChunkedInferenceServiceResults getResults(String text) { - return resultMap.getOrDefault(text, new ChunkedSparseEmbeddingResults(List.of())); + SemanticTextField result = resultMap.get(text); + if (result == null) { + return new ChunkedSparseEmbeddingResults(List.of()); + } + return toChunkedResult(result); } - void putResult(String text, ChunkedInferenceServiceResults results) { - resultMap.put(text, results); + void putResult(String text, SemanticTextField result) { + resultMap.put(text, result); } } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapperTests.java deleted file mode 100644 index 37e4e5e774bec..0000000000000 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapperTests.java +++ /dev/null @@ -1,629 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.mapper; - -import org.apache.lucene.document.FeatureField; -import org.apache.lucene.index.IndexableField; -import org.apache.lucene.index.Term; -import org.apache.lucene.search.BooleanClause; -import org.apache.lucene.search.BooleanQuery; -import org.apache.lucene.search.IndexSearcher; -import org.apache.lucene.search.Query; -import org.apache.lucene.search.TermQuery; -import org.apache.lucene.search.TopDocs; -import org.apache.lucene.search.join.BitSetProducer; -import org.apache.lucene.search.join.QueryBitSetProducer; -import org.apache.lucene.search.join.ScoreMode; -import org.elasticsearch.common.Strings; -import org.elasticsearch.common.lucene.search.Queries; -import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.index.IndexVersion; -import org.elasticsearch.index.IndexVersions; -import org.elasticsearch.index.mapper.DocumentMapper; -import org.elasticsearch.index.mapper.DocumentParsingException; -import org.elasticsearch.index.mapper.LuceneDocument; -import org.elasticsearch.index.mapper.MapperService; -import org.elasticsearch.index.mapper.MetadataMapperTestCase; -import org.elasticsearch.index.mapper.NestedLookup; -import org.elasticsearch.index.mapper.NestedObjectMapper; -import org.elasticsearch.index.mapper.ParsedDocument; -import org.elasticsearch.index.search.ESToParentBlockJoinQuery; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; -import org.elasticsearch.inference.Model; -import org.elasticsearch.inference.TaskType; -import org.elasticsearch.plugins.Plugin; -import org.elasticsearch.search.LeafNestedDocuments; -import org.elasticsearch.search.NestedDocuments; -import org.elasticsearch.search.SearchHit; -import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xpack.core.inference.results.ChunkedSparseEmbeddingResults; -import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingResults; -import org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextExpansionResults; -import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; -import org.elasticsearch.xpack.inference.InferencePlugin; -import org.elasticsearch.xpack.inference.model.TestModel; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.Collection; -import java.util.HashMap; -import java.util.HashSet; -import java.util.LinkedHashMap; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.function.Consumer; - -import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper.CHUNKS; -import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper.INFERENCE_CHUNKS_RESULTS; -import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper.INFERENCE_CHUNKS_TEXT; -import static org.hamcrest.Matchers.containsString; -import static org.hamcrest.Matchers.equalTo; - -public class InferenceMetadataFieldMapperTests extends MetadataMapperTestCase { - private record SemanticTextInferenceResults(String fieldName, Model model, ChunkedInferenceServiceResults results, List text) {} - - private record VisitedChildDocInfo(String path) {} - - private record SparseVectorSubfieldOptions(boolean include, boolean includeEmbedding, boolean includeIsTruncated) {} - - @Override - protected String fieldName() { - return InferenceMetadataFieldMapper.NAME; - } - - @Override - protected boolean isConfigurable() { - return false; - } - - @Override - protected boolean isSupportedOn(IndexVersion version) { - return version.onOrAfter(IndexVersions.ES_VERSION_8_12_1); // TODO: Switch to ES_VERSION_8_14 when available - } - - @Override - protected void registerParameters(ParameterChecker checker) throws IOException { - - } - - @Override - protected Collection getPlugins() { - return List.of(new InferencePlugin(Settings.EMPTY)); - } - - public void testSuccessfulParse() throws IOException { - for (int depth = 1; depth < 4; depth++) { - final String fieldName1 = randomFieldName(depth); - final String fieldName2 = randomFieldName(depth + 1); - - Model model1 = randomModel(TaskType.SPARSE_EMBEDDING); - Model model2 = randomModel(TaskType.SPARSE_EMBEDDING); - XContentBuilder mapping = mapping(b -> { - addSemanticTextMapping(b, fieldName1, model1.getInferenceEntityId()); - addSemanticTextMapping(b, fieldName2, model2.getInferenceEntityId()); - }); - - MapperService mapperService = createMapperService(mapping); - SemanticTextFieldMapperTests.assertSemanticTextField(mapperService, fieldName1, false); - SemanticTextFieldMapperTests.assertSemanticTextField(mapperService, fieldName2, false); - DocumentMapper documentMapper = mapperService.documentMapper(); - ParsedDocument doc = documentMapper.parse( - source( - b -> addSemanticTextInferenceResults( - b, - List.of( - randomSemanticTextInferenceResults(fieldName1, model1, List.of("a b", "c")), - randomSemanticTextInferenceResults(fieldName2, model2, List.of("d e f")) - ) - ) - ) - ); - - List luceneDocs = doc.docs(); - assertEquals(4, luceneDocs.size()); - for (int i = 0; i < 3; i++) { - assertEquals(doc.rootDoc(), luceneDocs.get(i).getParent()); - } - // nested docs are in reversed order - assertSparseFeatures(luceneDocs.get(0), fieldName1 + ".chunks.inference", 2); - assertSparseFeatures(luceneDocs.get(1), fieldName1 + ".chunks.inference", 1); - assertSparseFeatures(luceneDocs.get(2), fieldName2 + ".chunks.inference", 3); - assertEquals(doc.rootDoc(), luceneDocs.get(3)); - assertNull(luceneDocs.get(3).getParent()); - - withLuceneIndex(mapperService, iw -> iw.addDocuments(doc.docs()), reader -> { - NestedDocuments nested = new NestedDocuments( - mapperService.mappingLookup(), - QueryBitSetProducer::new, - IndexVersion.current() - ); - LeafNestedDocuments leaf = nested.getLeafNestedDocuments(reader.leaves().get(0)); - - Set visitedNestedIdentities = new HashSet<>(); - Set expectedVisitedNestedIdentities = Set.of( - new SearchHit.NestedIdentity(fieldName1 + "." + CHUNKS, 0, null), - new SearchHit.NestedIdentity(fieldName1 + "." + CHUNKS, 1, null), - new SearchHit.NestedIdentity(fieldName2 + "." + CHUNKS, 0, null) - ); - - assertChildLeafNestedDocument(leaf, 0, 3, visitedNestedIdentities); - assertChildLeafNestedDocument(leaf, 1, 3, visitedNestedIdentities); - assertChildLeafNestedDocument(leaf, 2, 3, visitedNestedIdentities); - assertEquals(expectedVisitedNestedIdentities, visitedNestedIdentities); - - assertNull(leaf.advance(3)); - assertEquals(3, leaf.doc()); - assertEquals(3, leaf.rootDoc()); - assertNull(leaf.nestedIdentity()); - - IndexSearcher searcher = newSearcher(reader); - { - TopDocs topDocs = searcher.search( - generateNestedTermSparseVectorQuery( - mapperService.mappingLookup().nestedLookup(), - fieldName1 + "." + CHUNKS, - List.of("a") - ), - 10 - ); - assertEquals(1, topDocs.totalHits.value); - assertEquals(3, topDocs.scoreDocs[0].doc); - } - { - TopDocs topDocs = searcher.search( - generateNestedTermSparseVectorQuery( - mapperService.mappingLookup().nestedLookup(), - fieldName1 + "." + CHUNKS, - List.of("a", "b") - ), - 10 - ); - assertEquals(1, topDocs.totalHits.value); - assertEquals(3, topDocs.scoreDocs[0].doc); - } - { - TopDocs topDocs = searcher.search( - generateNestedTermSparseVectorQuery( - mapperService.mappingLookup().nestedLookup(), - fieldName2 + "." + CHUNKS, - List.of("d") - ), - 10 - ); - assertEquals(1, topDocs.totalHits.value); - assertEquals(3, topDocs.scoreDocs[0].doc); - } - { - TopDocs topDocs = searcher.search( - generateNestedTermSparseVectorQuery( - mapperService.mappingLookup().nestedLookup(), - fieldName2 + "." + CHUNKS, - List.of("z") - ), - 10 - ); - assertEquals(0, topDocs.totalHits.value); - } - }); - } - } - - public void testMissingSubfields() throws IOException { - final String fieldName = randomAlphaOfLengthBetween(5, 15); - final Model model = randomModel(randomBoolean() ? TaskType.SPARSE_EMBEDDING : TaskType.TEXT_EMBEDDING); - - DocumentMapper documentMapper = createDocumentMapper( - mapping(b -> addSemanticTextMapping(b, fieldName, model.getInferenceEntityId())) - ); - - { - DocumentParsingException ex = expectThrows( - DocumentParsingException.class, - DocumentParsingException.class, - () -> documentMapper.parse( - source( - b -> addSemanticTextInferenceResults( - b, - List.of(randomSemanticTextInferenceResults(fieldName, model, List.of("a b"))), - new SparseVectorSubfieldOptions(false, true, true), - true, - Map.of() - ) - ) - ) - ); - assertThat(ex.getMessage(), containsString("Missing required subfields: [" + INFERENCE_CHUNKS_RESULTS + "]")); - } - { - DocumentParsingException ex = expectThrows( - DocumentParsingException.class, - DocumentParsingException.class, - () -> documentMapper.parse( - source( - b -> addSemanticTextInferenceResults( - b, - List.of(randomSemanticTextInferenceResults(fieldName, model, List.of("a b"))), - new SparseVectorSubfieldOptions(true, true, true), - false, - Map.of() - ) - ) - ) - ); - assertThat(ex.getMessage(), containsString("Missing required subfields: [" + INFERENCE_CHUNKS_TEXT + "]")); - } - { - DocumentParsingException ex = expectThrows( - DocumentParsingException.class, - DocumentParsingException.class, - () -> documentMapper.parse( - source( - b -> addSemanticTextInferenceResults( - b, - List.of(randomSemanticTextInferenceResults(fieldName, model, List.of("a b"))), - new SparseVectorSubfieldOptions(false, true, true), - false, - Map.of() - ) - ) - ) - ); - assertThat( - ex.getMessage(), - containsString("Missing required subfields: [" + INFERENCE_CHUNKS_RESULTS + ", " + INFERENCE_CHUNKS_TEXT + "]") - ); - } - } - - public void testExtraSubfields() throws IOException { - final String fieldName = randomAlphaOfLengthBetween(5, 15); - final Model model = randomModel(randomBoolean() ? TaskType.SPARSE_EMBEDDING : TaskType.TEXT_EMBEDDING); - final List semanticTextInferenceResultsList = List.of( - randomSemanticTextInferenceResults(fieldName, model, List.of("a b")) - ); - - DocumentMapper documentMapper = createDocumentMapper( - mapping(b -> addSemanticTextMapping(b, fieldName, model.getInferenceEntityId())) - ); - - Consumer checkParsedDocument = d -> { - Set visitedChildDocs = new HashSet<>(); - Set expectedVisitedChildDocs = Set.of(new VisitedChildDocInfo(fieldName + "." + CHUNKS)); - - List luceneDocs = d.docs(); - assertEquals(2, luceneDocs.size()); - assertValidChildDoc(luceneDocs.get(0), d.rootDoc(), visitedChildDocs); - assertEquals(d.rootDoc(), luceneDocs.get(1)); - assertNull(luceneDocs.get(1).getParent()); - assertEquals(expectedVisitedChildDocs, visitedChildDocs); - }; - - { - ParsedDocument doc = documentMapper.parse( - source( - b -> addSemanticTextInferenceResults( - b, - semanticTextInferenceResultsList, - new SparseVectorSubfieldOptions(true, true, true), - true, - Map.of("extra_key", "extra_value") - ) - ) - ); - - checkParsedDocument.accept(doc); - LuceneDocument childDoc = doc.docs().get(0); - assertEquals(0, childDoc.getFields(childDoc.getPath() + ".extra_key").size()); - } - { - ParsedDocument doc = documentMapper.parse( - source( - b -> addSemanticTextInferenceResults( - b, - semanticTextInferenceResultsList, - new SparseVectorSubfieldOptions(true, true, true), - true, - Map.of("extra_key", Map.of("k1", "v1")) - ) - ) - ); - - checkParsedDocument.accept(doc); - LuceneDocument childDoc = doc.docs().get(0); - assertEquals(0, childDoc.getFields(childDoc.getPath() + ".extra_key").size()); - } - { - ParsedDocument doc = documentMapper.parse( - source( - b -> addSemanticTextInferenceResults( - b, - semanticTextInferenceResultsList, - new SparseVectorSubfieldOptions(true, true, true), - true, - Map.of("extra_key", List.of("v1")) - ) - ) - ); - - checkParsedDocument.accept(doc); - LuceneDocument childDoc = doc.docs().get(0); - assertEquals(0, childDoc.getFields(childDoc.getPath() + ".extra_key").size()); - } - { - Map extraSubfields = new HashMap<>(); - extraSubfields.put("extra_key", null); - - ParsedDocument doc = documentMapper.parse( - source( - b -> addSemanticTextInferenceResults( - b, - semanticTextInferenceResultsList, - new SparseVectorSubfieldOptions(true, true, true), - true, - extraSubfields - ) - ) - ); - - checkParsedDocument.accept(doc); - LuceneDocument childDoc = doc.docs().get(0); - assertEquals(0, childDoc.getFields(childDoc.getPath() + ".extra_key").size()); - } - } - - public void testMissingSemanticTextMapping() throws IOException { - final String fieldName = randomAlphaOfLengthBetween(5, 15); - - DocumentMapper documentMapper = createDocumentMapper(mapping(b -> {})); - DocumentParsingException ex = expectThrows( - DocumentParsingException.class, - DocumentParsingException.class, - () -> documentMapper.parse( - source( - b -> addSemanticTextInferenceResults( - b, - List.of( - randomSemanticTextInferenceResults( - fieldName, - randomModel(randomFrom(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING)), - List.of("a b") - ) - ) - ) - ) - ) - ); - assertThat( - ex.getMessage(), - containsString( - Strings.format("Field [%s] is not registered as a [%s] field type", fieldName, SemanticTextFieldMapper.CONTENT_TYPE) - ) - ); - } - - public void testMissingInferenceId() throws IOException { - DocumentMapper documentMapper = createDocumentMapper(mapping(b -> addSemanticTextMapping(b, "field", "my_id"))); - IllegalArgumentException ex = expectThrows( - DocumentParsingException.class, - IllegalArgumentException.class, - () -> documentMapper.parse( - source( - b -> b.startObject(InferenceMetadataFieldMapper.NAME) - .startObject("field") - .startObject(SemanticTextModelSettings.NAME) - .field(SemanticTextModelSettings.TASK_TYPE_FIELD.getPreferredName(), TaskType.SPARSE_EMBEDDING) - .endObject() - .endObject() - .endObject() - ) - ) - ); - assertThat(ex.getMessage(), containsString("required [inference_id] is missing")); - } - - public void testMissingModelSettings() throws IOException { - DocumentMapper documentMapper = createDocumentMapper(mapping(b -> addSemanticTextMapping(b, "field", "my_id"))); - DocumentParsingException ex = expectThrows( - DocumentParsingException.class, - DocumentParsingException.class, - () -> documentMapper.parse( - source( - b -> b.startObject(InferenceMetadataFieldMapper.NAME) - .startObject("field") - .field(InferenceMetadataFieldMapper.INFERENCE_ID, "my_id") - .endObject() - .endObject() - ) - ) - ); - assertThat(ex.getMessage(), containsString("Missing required [model_settings] for field [field] of type [semantic_text]")); - } - - public void testMissingTaskType() throws IOException { - DocumentMapper documentMapper = createDocumentMapper(mapping(b -> addSemanticTextMapping(b, "field", "my_id"))); - DocumentParsingException ex = expectThrows( - DocumentParsingException.class, - DocumentParsingException.class, - () -> documentMapper.parse( - source( - b -> b.startObject(InferenceMetadataFieldMapper.NAME) - .startObject("field") - .field(InferenceMetadataFieldMapper.INFERENCE_ID, "my_id") - .startObject(SemanticTextModelSettings.NAME) - .endObject() - .endObject() - .endObject() - ) - ) - ); - assertThat(ex.getCause().getMessage(), containsString(" Failed to parse [model_settings], required [task_type] is missing")); - } - - private static void addSemanticTextMapping(XContentBuilder mappingBuilder, String fieldName, String modelId) throws IOException { - mappingBuilder.startObject(fieldName); - mappingBuilder.field("type", SemanticTextFieldMapper.CONTENT_TYPE); - mappingBuilder.field("inference_id", modelId); - mappingBuilder.endObject(); - } - - public static ChunkedTextEmbeddingResults randomTextEmbeddings(Model model, List inputs) { - List chunks = new ArrayList<>(); - for (String input : inputs) { - double[] values = new double[model.getServiceSettings().dimensions()]; - for (int j = 0; j < values.length; j++) { - values[j] = randomDouble(); - } - chunks.add(new org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextEmbeddingResults.EmbeddingChunk(input, values)); - } - return new ChunkedTextEmbeddingResults(chunks); - } - - public static ChunkedSparseEmbeddingResults randomSparseEmbeddings(List inputs) { - List chunks = new ArrayList<>(); - for (String input : inputs) { - var tokens = new ArrayList(); - for (var token : input.split("\\s+")) { - tokens.add(new TextExpansionResults.WeightedToken(token, randomFloat())); - } - chunks.add(new ChunkedTextExpansionResults.ChunkedResult(input, tokens)); - } - return new ChunkedSparseEmbeddingResults(chunks); - } - - private static SemanticTextInferenceResults randomSemanticTextInferenceResults( - String semanticTextFieldName, - Model model, - List chunks - ) { - ChunkedInferenceServiceResults chunkedResults = switch (model.getTaskType()) { - case TEXT_EMBEDDING -> randomTextEmbeddings(model, chunks); - case SPARSE_EMBEDDING -> randomSparseEmbeddings(chunks); - default -> throw new AssertionError("unkwnown task type: " + model.getTaskType().name()); - }; - return new SemanticTextInferenceResults(semanticTextFieldName, model, chunkedResults, chunks); - } - - private static void addSemanticTextInferenceResults( - XContentBuilder sourceBuilder, - List semanticTextInferenceResults - ) throws IOException { - addSemanticTextInferenceResults( - sourceBuilder, - semanticTextInferenceResults, - new SparseVectorSubfieldOptions(true, true, true), - true, - Map.of() - ); - } - - @SuppressWarnings("unchecked") - private static void addSemanticTextInferenceResults( - XContentBuilder sourceBuilder, - List semanticTextInferenceResults, - SparseVectorSubfieldOptions sparseVectorSubfieldOptions, - boolean includeTextSubfield, - Map extraSubfields - ) throws IOException { - Map inferenceResultsMap = new LinkedHashMap<>(); - for (SemanticTextInferenceResults semanticTextInferenceResult : semanticTextInferenceResults) { - InferenceMetadataFieldMapper.applyFieldInference( - inferenceResultsMap, - semanticTextInferenceResult.fieldName, - semanticTextInferenceResult.model, - semanticTextInferenceResult.results - ); - Map optionsMap = (Map) inferenceResultsMap.get(semanticTextInferenceResult.fieldName); - List> fieldResultList = (List>) optionsMap.get(CHUNKS); - for (var entry : fieldResultList) { - if (includeTextSubfield == false) { - entry.remove(INFERENCE_CHUNKS_TEXT); - } - if (sparseVectorSubfieldOptions.include == false) { - entry.remove(INFERENCE_CHUNKS_RESULTS); - } - entry.putAll(extraSubfields); - } - } - sourceBuilder.field(InferenceMetadataFieldMapper.NAME, inferenceResultsMap); - } - - static String randomFieldName(int numLevel) { - StringBuilder builder = new StringBuilder(); - for (int i = 0; i < numLevel; i++) { - if (i > 0) { - builder.append('.'); - } - builder.append(randomAlphaOfLengthBetween(5, 15)); - } - return builder.toString(); - } - - private static Model randomModel(TaskType taskType) { - String serviceName = randomAlphaOfLengthBetween(5, 10); - String inferenceId = randomAlphaOfLengthBetween(5, 10); - return new TestModel( - inferenceId, - taskType, - serviceName, - new TestModel.TestServiceSettings("my-model"), - new TestModel.TestTaskSettings(randomIntBetween(1, 100)), - new TestModel.TestSecretSettings(randomAlphaOfLength(10)) - ); - } - - private static Query generateNestedTermSparseVectorQuery(NestedLookup nestedLookup, String path, List tokens) { - NestedObjectMapper mapper = nestedLookup.getNestedMappers().get(path); - assertNotNull(mapper); - - BitSetProducer parentFilter = new QueryBitSetProducer(Queries.newNonNestedFilter(IndexVersion.current())); - BooleanQuery.Builder queryBuilder = new BooleanQuery.Builder(); - for (String token : tokens) { - queryBuilder.add( - new BooleanClause(new TermQuery(new Term(path + "." + INFERENCE_CHUNKS_RESULTS, token)), BooleanClause.Occur.MUST) - ); - } - queryBuilder.add(new BooleanClause(mapper.nestedTypeFilter(), BooleanClause.Occur.FILTER)); - - return new ESToParentBlockJoinQuery(queryBuilder.build(), parentFilter, ScoreMode.Total, null); - } - - private static void assertValidChildDoc( - LuceneDocument childDoc, - LuceneDocument expectedParent, - Collection visitedChildDocs - ) { - assertEquals(expectedParent, childDoc.getParent()); - visitedChildDocs.add(new VisitedChildDocInfo(childDoc.getPath())); - } - - private static void assertChildLeafNestedDocument( - LeafNestedDocuments leaf, - int advanceToDoc, - int expectedRootDoc, - Set visitedNestedIdentities - ) throws IOException { - - assertNotNull(leaf.advance(advanceToDoc)); - assertEquals(advanceToDoc, leaf.doc()); - assertEquals(expectedRootDoc, leaf.rootDoc()); - assertNotNull(leaf.nestedIdentity()); - visitedNestedIdentities.add(leaf.nestedIdentity()); - } - - private static void assertSparseFeatures(LuceneDocument doc, String fieldName, int expectedCount) { - int count = 0; - for (IndexableField field : doc.getFields()) { - if (field instanceof FeatureField featureField) { - assertThat(featureField.name(), equalTo(fieldName)); - ++count; - } - } - assertThat(count, equalTo(expectedCount)); - } -} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java index 1b5311ac9effb..a6f0fa83eab37 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java @@ -7,32 +7,65 @@ package org.elasticsearch.xpack.inference.mapper; +import org.apache.lucene.document.FeatureField; import org.apache.lucene.index.IndexableField; +import org.apache.lucene.index.Term; +import org.apache.lucene.search.BooleanClause; +import org.apache.lucene.search.BooleanQuery; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.TermQuery; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.join.BitSetProducer; +import org.apache.lucene.search.join.QueryBitSetProducer; +import org.apache.lucene.search.join.ScoreMode; import org.elasticsearch.common.Strings; +import org.elasticsearch.common.lucene.search.Queries; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.index.IndexVersion; import org.elasticsearch.index.mapper.DocumentMapper; +import org.elasticsearch.index.mapper.DocumentParsingException; import org.elasticsearch.index.mapper.KeywordFieldMapper; +import org.elasticsearch.index.mapper.LuceneDocument; import org.elasticsearch.index.mapper.MappedFieldType; import org.elasticsearch.index.mapper.Mapper; -import org.elasticsearch.index.mapper.MapperBuilderContext; import org.elasticsearch.index.mapper.MapperParsingException; import org.elasticsearch.index.mapper.MapperService; import org.elasticsearch.index.mapper.MapperTestCase; +import org.elasticsearch.index.mapper.NestedLookup; import org.elasticsearch.index.mapper.NestedObjectMapper; import org.elasticsearch.index.mapper.ParsedDocument; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper; +import org.elasticsearch.index.search.ESToParentBlockJoinQuery; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.search.LeafNestedDocuments; +import org.elasticsearch.search.NestedDocuments; +import org.elasticsearch.search.SearchHit; import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.inference.InferencePlugin; import org.junit.AssumptionViolatedException; import java.io.IOException; import java.util.Collection; +import java.util.HashSet; import java.util.List; +import java.util.Set; import static java.util.Collections.singletonList; -import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper.createSemanticFieldContext; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.CHUNKED_EMBEDDINGS_FIELD; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.CHUNKED_TEXT_FIELD; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.CHUNKS_FIELD; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.INFERENCE_FIELD; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.INFERENCE_ID_FIELD; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.MODEL_SETTINGS_FIELD; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.getChunksFieldName; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.getEmbeddingsFieldName; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomModel; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomSemanticText; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.instanceOf; @@ -55,7 +88,7 @@ protected String minimalIsInvalidRoutingPathErrorMessage(Mapper mapper) { @Override protected Object getSampleValueForDocument() { - return "value"; + return null; } @Override @@ -98,7 +131,7 @@ public void testDefaults() throws Exception { assertTrue(fields.isEmpty()); } - public void testInferenceIdNotPresent() throws IOException { + public void testInferenceIdNotPresent() { Exception e = expectThrows( MapperParsingException.class, () -> createMapperService(fieldMapping(b -> b.field("type", "semantic_text"))) @@ -112,6 +145,7 @@ public void testCannotBeUsedInMultiFields() { b.startObject("fields"); b.startObject("semantic"); b.field("type", "semantic_text"); + b.field("inference_id", "my_inference_id"); b.endObject(); b.endObject(); }))); @@ -136,7 +170,7 @@ public void testUpdatesToInferenceIdNotSupported() throws IOException { public void testUpdateModelSettings() throws IOException { for (int depth = 1; depth < 5; depth++) { - String fieldName = InferenceMetadataFieldMapperTests.randomFieldName(depth); + String fieldName = randomFieldName(depth); MapperService mapperService = createMapperService( mapping(b -> b.startObject(fieldName).field("type", "semantic_text").field("inference_id", "test_model").endObject()) ); @@ -157,7 +191,7 @@ public void testUpdateModelSettings() throws IOException { ) ) ); - assertThat(exc.getMessage(), containsString("Failed to parse [model_settings], required [task_type] is missing")); + assertThat(exc.getMessage(), containsString("Required [task_type]")); } { merge( @@ -220,12 +254,7 @@ public void testUpdateModelSettings() throws IOException { } static void assertSemanticTextField(MapperService mapperService, String fieldName, boolean expectedModelSettings) { - InferenceMetadataFieldMapper.SemanticTextMapperContext res = createSemanticFieldContext( - MapperBuilderContext.root(false, false), - mapperService.mappingLookup().getMapping().getRoot(), - fieldName.split("\\.") - ); - Mapper mapper = res.mapper(); + Mapper mapper = mapperService.mappingLookup().getMapper(fieldName); assertNotNull(mapper); assertThat(mapper, instanceOf(SemanticTextFieldMapper.class)); SemanticTextFieldMapper semanticFieldMapper = (SemanticTextFieldMapper) mapper; @@ -235,31 +264,257 @@ static void assertSemanticTextField(MapperService mapperService, String fieldNam assertThat(fieldType, instanceOf(SemanticTextFieldMapper.SemanticTextFieldType.class)); SemanticTextFieldMapper.SemanticTextFieldType semanticTextFieldType = (SemanticTextFieldMapper.SemanticTextFieldType) fieldType; assertTrue(semanticFieldMapper.fieldType() == semanticTextFieldType); - assertTrue(semanticFieldMapper.getSubMappers() == semanticTextFieldType.getSubMappers()); - assertTrue(semanticFieldMapper.getModelSettings() == semanticTextFieldType.getModelSettings()); - NestedObjectMapper nestedObjectMapper = mapperService.mappingLookup() + NestedObjectMapper chunksMapper = mapperService.mappingLookup() .nestedLookup() .getNestedMappers() - .get(fieldName + "." + InferenceMetadataFieldMapper.CHUNKS); - assertThat(nestedObjectMapper, equalTo(semanticFieldMapper.getSubMappers())); - Mapper textMapper = nestedObjectMapper.getMapper(InferenceMetadataFieldMapper.INFERENCE_CHUNKS_TEXT); + .get(getChunksFieldName(fieldName)); + assertThat(chunksMapper, equalTo(semanticFieldMapper.fieldType().getChunksField())); + Mapper textMapper = chunksMapper.getMapper(CHUNKED_TEXT_FIELD.getPreferredName()); assertNotNull(textMapper); assertThat(textMapper, instanceOf(KeywordFieldMapper.class)); KeywordFieldMapper textFieldMapper = (KeywordFieldMapper) textMapper; assertFalse(textFieldMapper.fieldType().isIndexed()); assertFalse(textFieldMapper.fieldType().hasDocValues()); if (expectedModelSettings) { - assertNotNull(semanticFieldMapper.getModelSettings()); - Mapper inferenceMapper = nestedObjectMapper.getMapper(InferenceMetadataFieldMapper.INFERENCE_CHUNKS_RESULTS); + assertNotNull(semanticFieldMapper.fieldType().getModelSettings()); + Mapper inferenceMapper = chunksMapper.getMapper(CHUNKED_EMBEDDINGS_FIELD.getPreferredName()); assertNotNull(inferenceMapper); - switch (semanticFieldMapper.getModelSettings().taskType()) { + switch (semanticFieldMapper.fieldType().getModelSettings().taskType()) { case SPARSE_EMBEDDING -> assertThat(inferenceMapper, instanceOf(SparseVectorFieldMapper.class)); case TEXT_EMBEDDING -> assertThat(inferenceMapper, instanceOf(DenseVectorFieldMapper.class)); default -> throw new AssertionError("Invalid task type"); } } else { - assertNull(semanticFieldMapper.getModelSettings()); + assertNull(semanticFieldMapper.fieldType().getModelSettings()); + } + } + + public void testSuccessfulParse() throws IOException { + for (int depth = 1; depth < 4; depth++) { + final String fieldName1 = randomFieldName(depth); + final String fieldName2 = randomFieldName(depth + 1); + + Model model1 = randomModel(TaskType.SPARSE_EMBEDDING); + Model model2 = randomModel(TaskType.SPARSE_EMBEDDING); + XContentBuilder mapping = mapping(b -> { + addSemanticTextMapping(b, fieldName1, model1.getInferenceEntityId()); + addSemanticTextMapping(b, fieldName2, model2.getInferenceEntityId()); + }); + + MapperService mapperService = createMapperService(mapping); + SemanticTextFieldMapperTests.assertSemanticTextField(mapperService, fieldName1, false); + SemanticTextFieldMapperTests.assertSemanticTextField(mapperService, fieldName2, false); + DocumentMapper documentMapper = mapperService.documentMapper(); + ParsedDocument doc = documentMapper.parse( + source( + b -> addSemanticTextInferenceResults( + b, + List.of( + randomSemanticText(fieldName1, model1, List.of("a b", "c"), XContentType.JSON), + randomSemanticText(fieldName2, model2, List.of("d e f"), XContentType.JSON) + ) + ) + ) + ); + + List luceneDocs = doc.docs(); + assertEquals(4, luceneDocs.size()); + for (int i = 0; i < 3; i++) { + assertEquals(doc.rootDoc(), luceneDocs.get(i).getParent()); + } + // nested docs are in reversed order + assertSparseFeatures(luceneDocs.get(0), getEmbeddingsFieldName(fieldName1), 2); + assertSparseFeatures(luceneDocs.get(1), getEmbeddingsFieldName(fieldName1), 1); + assertSparseFeatures(luceneDocs.get(2), getEmbeddingsFieldName(fieldName2), 3); + assertEquals(doc.rootDoc(), luceneDocs.get(3)); + assertNull(luceneDocs.get(3).getParent()); + + withLuceneIndex(mapperService, iw -> iw.addDocuments(doc.docs()), reader -> { + NestedDocuments nested = new NestedDocuments( + mapperService.mappingLookup(), + QueryBitSetProducer::new, + IndexVersion.current() + ); + LeafNestedDocuments leaf = nested.getLeafNestedDocuments(reader.leaves().get(0)); + + Set visitedNestedIdentities = new HashSet<>(); + Set expectedVisitedNestedIdentities = Set.of( + new SearchHit.NestedIdentity(getChunksFieldName(fieldName1), 0, null), + new SearchHit.NestedIdentity(getChunksFieldName(fieldName1), 1, null), + new SearchHit.NestedIdentity(getChunksFieldName(fieldName2), 0, null) + ); + + assertChildLeafNestedDocument(leaf, 0, 3, visitedNestedIdentities); + assertChildLeafNestedDocument(leaf, 1, 3, visitedNestedIdentities); + assertChildLeafNestedDocument(leaf, 2, 3, visitedNestedIdentities); + assertEquals(expectedVisitedNestedIdentities, visitedNestedIdentities); + + assertNull(leaf.advance(3)); + assertEquals(3, leaf.doc()); + assertEquals(3, leaf.rootDoc()); + assertNull(leaf.nestedIdentity()); + + IndexSearcher searcher = newSearcher(reader); + { + TopDocs topDocs = searcher.search( + generateNestedTermSparseVectorQuery(mapperService.mappingLookup().nestedLookup(), fieldName1, List.of("a")), + 10 + ); + assertEquals(1, topDocs.totalHits.value); + assertEquals(3, topDocs.scoreDocs[0].doc); + } + { + TopDocs topDocs = searcher.search( + generateNestedTermSparseVectorQuery(mapperService.mappingLookup().nestedLookup(), fieldName1, List.of("a", "b")), + 10 + ); + assertEquals(1, topDocs.totalHits.value); + assertEquals(3, topDocs.scoreDocs[0].doc); + } + { + TopDocs topDocs = searcher.search( + generateNestedTermSparseVectorQuery(mapperService.mappingLookup().nestedLookup(), fieldName2, List.of("d")), + 10 + ); + assertEquals(1, topDocs.totalHits.value); + assertEquals(3, topDocs.scoreDocs[0].doc); + } + { + TopDocs topDocs = searcher.search( + generateNestedTermSparseVectorQuery(mapperService.mappingLookup().nestedLookup(), fieldName2, List.of("z")), + 10 + ); + assertEquals(0, topDocs.totalHits.value); + } + }); + } + } + + public void testMissingInferenceId() throws IOException { + DocumentMapper documentMapper = createDocumentMapper(mapping(b -> addSemanticTextMapping(b, "field", "my_id"))); + IllegalArgumentException ex = expectThrows( + DocumentParsingException.class, + IllegalArgumentException.class, + () -> documentMapper.parse( + source( + b -> b.startObject("field") + .startObject(INFERENCE_FIELD.getPreferredName()) + .field( + MODEL_SETTINGS_FIELD.getPreferredName(), + new SemanticTextField.ModelSettings(TaskType.SPARSE_EMBEDDING, null, null) + ) + .field(CHUNKS_FIELD.getPreferredName(), List.of()) + .endObject() + .endObject() + ) + ) + ); + assertThat(ex.getCause().getMessage(), containsString("Required [inference_id]")); + } + + public void testMissingModelSettings() throws IOException { + DocumentMapper documentMapper = createDocumentMapper(mapping(b -> addSemanticTextMapping(b, "field", "my_id"))); + IllegalArgumentException ex = expectThrows( + DocumentParsingException.class, + IllegalArgumentException.class, + () -> documentMapper.parse( + source( + b -> b.startObject("field") + .startObject(INFERENCE_FIELD.getPreferredName()) + .field(INFERENCE_ID_FIELD.getPreferredName(), "my_id") + .endObject() + .endObject() + ) + ) + ); + assertThat(ex.getCause().getMessage(), containsString("Required [model_settings, chunks]")); + } + + public void testMissingTaskType() throws IOException { + DocumentMapper documentMapper = createDocumentMapper(mapping(b -> addSemanticTextMapping(b, "field", "my_id"))); + IllegalArgumentException ex = expectThrows( + DocumentParsingException.class, + IllegalArgumentException.class, + () -> documentMapper.parse( + source( + b -> b.startObject("field") + .startObject(INFERENCE_FIELD.getPreferredName()) + .field(INFERENCE_ID_FIELD.getPreferredName(), "my_id") + .startObject(MODEL_SETTINGS_FIELD.getPreferredName()) + .endObject() + .endObject() + .endObject() + ) + ) + ); + assertThat(ex.getCause().getMessage(), containsString("failed to parse field [model_settings]")); + } + + private static void addSemanticTextMapping(XContentBuilder mappingBuilder, String fieldName, String modelId) throws IOException { + mappingBuilder.startObject(fieldName); + mappingBuilder.field("type", SemanticTextFieldMapper.CONTENT_TYPE); + mappingBuilder.field("inference_id", modelId); + mappingBuilder.endObject(); + } + + private static void addSemanticTextInferenceResults(XContentBuilder sourceBuilder, List semanticTextInferenceResults) + throws IOException { + for (var field : semanticTextInferenceResults) { + sourceBuilder.field(field.fieldName()); + sourceBuilder.value(field); + } + } + + static String randomFieldName(int numLevel) { + StringBuilder builder = new StringBuilder(); + for (int i = 0; i < numLevel; i++) { + if (i > 0) { + builder.append('.'); + } + builder.append(randomAlphaOfLengthBetween(5, 15)); + } + return builder.toString(); + } + + private static Query generateNestedTermSparseVectorQuery(NestedLookup nestedLookup, String fieldName, List tokens) { + NestedObjectMapper mapper = nestedLookup.getNestedMappers().get(getChunksFieldName(fieldName)); + assertNotNull(mapper); + + BitSetProducer parentFilter = new QueryBitSetProducer(Queries.newNonNestedFilter(IndexVersion.current())); + BooleanQuery.Builder queryBuilder = new BooleanQuery.Builder(); + for (String token : tokens) { + queryBuilder.add( + new BooleanClause(new TermQuery(new Term(getEmbeddingsFieldName(fieldName), token)), BooleanClause.Occur.MUST) + ); + } + queryBuilder.add(new BooleanClause(mapper.nestedTypeFilter(), BooleanClause.Occur.FILTER)); + + return new ESToParentBlockJoinQuery(queryBuilder.build(), parentFilter, ScoreMode.Total, null); + } + + private static void assertChildLeafNestedDocument( + LeafNestedDocuments leaf, + int advanceToDoc, + int expectedRootDoc, + Set visitedNestedIdentities + ) throws IOException { + + assertNotNull(leaf.advance(advanceToDoc)); + assertEquals(advanceToDoc, leaf.doc()); + assertEquals(expectedRootDoc, leaf.rootDoc()); + assertNotNull(leaf.nestedIdentity()); + visitedNestedIdentities.add(leaf.nestedIdentity()); + } + + private static void assertSparseFeatures(LuceneDocument doc, String fieldName, int expectedCount) { + int count = 0; + for (IndexableField field : doc.getFields()) { + if (field instanceof FeatureField featureField) { + assertThat(featureField.name(), equalTo(fieldName)); + ++count; + } } + assertThat(count, equalTo(expectedCount)); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java new file mode 100644 index 0000000000000..e6bdb7271163b --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java @@ -0,0 +1,219 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.mapper; + +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.core.Tuple; +import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.AbstractXContentTestCase; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.results.ChunkedSparseEmbeddingResults; +import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingResults; +import org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextExpansionResults; +import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; +import org.elasticsearch.xpack.inference.model.TestModel; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.function.Predicate; + +import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.CHUNKED_EMBEDDINGS_FIELD; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.toSemanticTextFieldChunks; +import static org.hamcrest.Matchers.equalTo; + +public class SemanticTextFieldTests extends AbstractXContentTestCase { + private static final String NAME = "field"; + + @Override + protected Predicate getRandomFieldsExcludeFilter() { + return n -> n.endsWith(CHUNKED_EMBEDDINGS_FIELD.getPreferredName()); + } + + @Override + protected void assertEqualInstances(SemanticTextField expectedInstance, SemanticTextField newInstance) { + assertThat(newInstance.fieldName(), equalTo(expectedInstance.fieldName())); + assertThat(newInstance.raw(), equalTo(expectedInstance.raw())); + assertThat(newInstance.inference().modelSettings(), equalTo(expectedInstance.inference().modelSettings())); + assertThat(newInstance.inference().chunks().size(), equalTo(expectedInstance.inference().chunks().size())); + SemanticTextField.ModelSettings modelSettings = newInstance.inference().modelSettings(); + for (int i = 0; i < newInstance.inference().chunks().size(); i++) { + assertThat(newInstance.inference().chunks().get(i).text(), equalTo(expectedInstance.inference().chunks().get(i).text())); + switch (modelSettings.taskType()) { + case TEXT_EMBEDDING -> { + double[] expectedVector = parseDenseVector( + expectedInstance.inference().chunks().get(i).rawEmbeddings(), + modelSettings.dimensions(), + expectedInstance.contentType() + ); + double[] newVector = parseDenseVector( + newInstance.inference().chunks().get(i).rawEmbeddings(), + modelSettings.dimensions(), + newInstance.contentType() + ); + assertArrayEquals(expectedVector, newVector, 0f); + } + case SPARSE_EMBEDDING -> { + List expectedTokens = parseWeightedTokens( + expectedInstance.inference().chunks().get(i).rawEmbeddings(), + expectedInstance.contentType() + ); + List newTokens = parseWeightedTokens( + newInstance.inference().chunks().get(i).rawEmbeddings(), + newInstance.contentType() + ); + assertThat(newTokens, equalTo(expectedTokens)); + } + default -> throw new AssertionError("Invalid task type " + modelSettings.taskType()); + } + } + } + + @Override + protected SemanticTextField createTestInstance() { + List rawValues = randomList(1, 5, () -> randomAlphaOfLengthBetween(10, 20)); + return randomSemanticText( + NAME, + randomModel(randomFrom(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING)), + rawValues, + randomFrom(XContentType.values()) + ); + } + + @Override + protected SemanticTextField doParseInstance(XContentParser parser) throws IOException { + return SemanticTextField.parse(parser, new Tuple<>(NAME, parser.contentType())); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } + + public static ChunkedTextEmbeddingResults randomTextEmbeddings(Model model, List inputs) { + List chunks = new ArrayList<>(); + for (String input : inputs) { + double[] values = new double[model.getServiceSettings().dimensions()]; + for (int j = 0; j < values.length; j++) { + values[j] = randomDouble(); + } + chunks.add(new org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextEmbeddingResults.EmbeddingChunk(input, values)); + } + return new ChunkedTextEmbeddingResults(chunks); + } + + public static ChunkedSparseEmbeddingResults randomSparseEmbeddings(List inputs) { + List chunks = new ArrayList<>(); + for (String input : inputs) { + var tokens = new ArrayList(); + for (var token : input.split("\\s+")) { + tokens.add(new TextExpansionResults.WeightedToken(token, randomFloat())); + } + chunks.add(new ChunkedTextExpansionResults.ChunkedResult(input, tokens)); + } + return new ChunkedSparseEmbeddingResults(chunks); + } + + public static SemanticTextField randomSemanticText(String fieldName, Model model, List inputs, XContentType contentType) { + ChunkedInferenceServiceResults results = switch (model.getTaskType()) { + case TEXT_EMBEDDING -> randomTextEmbeddings(model, inputs); + case SPARSE_EMBEDDING -> randomSparseEmbeddings(inputs); + default -> throw new AssertionError("invalid task type: " + model.getTaskType().name()); + }; + return new SemanticTextField( + fieldName, + inputs, + new SemanticTextField.InferenceResult( + model.getInferenceEntityId(), + new SemanticTextField.ModelSettings(model), + toSemanticTextFieldChunks(fieldName, model.getInferenceEntityId(), List.of(results), contentType) + ), + contentType + ); + } + + public static Model randomModel(TaskType taskType) { + String serviceName = randomAlphaOfLengthBetween(5, 10); + String inferenceId = randomAlphaOfLengthBetween(5, 10); + return new TestModel( + inferenceId, + taskType, + serviceName, + new TestModel.TestServiceSettings("my-model"), + new TestModel.TestTaskSettings(randomIntBetween(1, 100)), + new TestModel.TestSecretSettings(randomAlphaOfLength(10)) + ); + } + + public static ChunkedInferenceServiceResults toChunkedResult(SemanticTextField field) { + switch (field.inference().modelSettings().taskType()) { + case SPARSE_EMBEDDING -> { + List chunks = new ArrayList<>(); + for (var chunk : field.inference().chunks()) { + var tokens = parseWeightedTokens(chunk.rawEmbeddings(), field.contentType()); + chunks.add(new ChunkedTextExpansionResults.ChunkedResult(chunk.text(), tokens)); + } + return new ChunkedSparseEmbeddingResults(chunks); + } + case TEXT_EMBEDDING -> { + List chunks = + new ArrayList<>(); + for (var chunk : field.inference().chunks()) { + double[] values = parseDenseVector( + chunk.rawEmbeddings(), + field.inference().modelSettings().dimensions(), + field.contentType() + ); + chunks.add( + new org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextEmbeddingResults.EmbeddingChunk( + chunk.text(), + values + ) + ); + } + return new ChunkedTextEmbeddingResults(chunks); + } + default -> throw new AssertionError("Invalid task_type: " + field.inference().modelSettings().taskType().name()); + } + } + + private static double[] parseDenseVector(BytesReference value, int numDims, XContentType contentType) { + try (XContentParser parser = XContentHelper.createParserNotCompressed(XContentParserConfiguration.EMPTY, value, contentType)) { + parser.nextToken(); + assertThat(parser.currentToken(), equalTo(XContentParser.Token.START_ARRAY)); + double[] values = new double[numDims]; + for (int i = 0; i < numDims; i++) { + assertThat(parser.nextToken(), equalTo(XContentParser.Token.VALUE_NUMBER)); + values[i] = parser.doubleValue(); + } + assertThat(parser.nextToken(), equalTo(XContentParser.Token.END_ARRAY)); + return values; + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + private static List parseWeightedTokens(BytesReference value, XContentType contentType) { + try (XContentParser parser = XContentHelper.createParserNotCompressed(XContentParserConfiguration.EMPTY, value, contentType)) { + Map map = parser.map(); + List weightedTokens = new ArrayList<>(); + for (var entry : map.entrySet()) { + weightedTokens.add(new TextExpansionResults.WeightedToken(entry.getKey(), ((Number) entry.getValue()).floatValue())); + } + return weightedTokens; + } catch (IOException e) { + throw new RuntimeException(e); + } + } +} diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml index 0a07a88d230ef..1aa3f2752365c 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml @@ -80,16 +80,14 @@ setup: index: test-sparse-index id: doc_1 - - match: { _source.inference_field: "inference test" } - - match: { _source.another_inference_field: "another inference test" } + - match: { _source.inference_field.raw: "inference test" } + - exists: _source.inference_field.inference.chunks.0.embeddings + - match: { _source.inference_field.inference.chunks.0.text: "inference test" } + - match: { _source.another_inference_field.raw: "another inference test" } + - exists: _source.another_inference_field.inference.chunks.0.embeddings + - match: { _source.another_inference_field.inference.chunks.0.text: "another inference test" } - match: { _source.non_inference_field: "non inference test" } - - match: { _source._inference.inference_field.chunks.0.text: "inference test" } - - match: { _source._inference.another_inference_field.chunks.0.text: "another inference test" } - - - exists: _source._inference.inference_field.chunks.0.inference - - exists: _source._inference.another_inference_field.chunks.0.inference - --- "text expansion documents do not create new mappings": - do: @@ -117,16 +115,14 @@ setup: index: test-dense-index id: doc_1 - - match: { _source.inference_field: "inference test" } - - match: { _source.another_inference_field: "another inference test" } + - match: { _source.inference_field.raw: "inference test" } + - exists: _source.inference_field.inference.chunks.0.embeddings + - match: { _source.inference_field.inference.chunks.0.text: "inference test" } + - match: { _source.another_inference_field.raw: "another inference test" } + - match: { _source.another_inference_field.inference.chunks.0.text: "another inference test" } + - exists: _source.another_inference_field.inference.chunks.0.embeddings - match: { _source.non_inference_field: "non inference test" } - - match: { _source._inference.inference_field.chunks.0.text: "inference test" } - - match: { _source._inference.another_inference_field.chunks.0.text: "another inference test" } - - - exists: _source._inference.inference_field.chunks.0.inference - - exists: _source._inference.another_inference_field.chunks.0.inference - --- "text embeddings documents do not create new mappings": @@ -155,8 +151,8 @@ setup: index: test-sparse-index id: doc_1 - - set: { _source._inference.inference_field.chunks.0.inference: inference_field_embedding } - - set: { _source._inference.another_inference_field.chunks.0.inference: another_inference_field_embedding } + - set: { _source.inference_field.inference.chunks.0.embeddings: inference_field_embedding } + - set: { _source.another_inference_field.inference.chunks.0.embeddings: another_inference_field_embedding } - do: update: @@ -171,17 +167,14 @@ setup: index: test-sparse-index id: doc_1 - - match: { _source.inference_field: "inference test" } - - match: { _source.another_inference_field: "another inference test" } + - match: { _source.inference_field.raw: "inference test" } + - match: { _source.inference_field.inference.chunks.0.text: "inference test" } + - match: { _source.inference_field.inference.chunks.0.embeddings: $inference_field_embedding } + - match: { _source.another_inference_field.raw: "another inference test" } + - match: { _source.another_inference_field.inference.chunks.0.text: "another inference test" } + - match: { _source.another_inference_field.inference.chunks.0.embeddings: $another_inference_field_embedding } - match: { _source.non_inference_field: "another non inference test" } - - length: { _source._inference: 2 } - - match: { _source._inference.inference_field.chunks.0.text: "inference test" } - - match: { _source._inference.another_inference_field.chunks.0.text: "another inference test" } - - - match: { _source._inference.inference_field.chunks.0.inference: $inference_field_embedding } - - match: { _source._inference.another_inference_field.chunks.0.inference: $another_inference_field_embedding } - --- "Updating semantic_text fields recalculates embeddings": - do: @@ -198,12 +191,11 @@ setup: index: test-sparse-index id: doc_1 - - match: { _source.inference_field: "inference test" } - - match: { _source.another_inference_field: "another inference test" } + - match: { _source.inference_field.raw: "inference test" } + - match: { _source.inference_field.inference.chunks.0.text: "inference test" } + - match: { _source.another_inference_field.raw: "another inference test" } + - match: { _source.another_inference_field.inference.chunks.0.text: "another inference test" } - match: { _source.non_inference_field: "non inference test" } - - length: { _source._inference: 2 } - - match: { _source._inference.inference_field.chunks.0.text: "inference test" } - - match: { _source._inference.another_inference_field.chunks.0.text: "another inference test" } - do: bulk: @@ -217,12 +209,11 @@ setup: index: test-sparse-index id: doc_1 - - match: { _source.inference_field: "I am a test" } - - match: { _source.another_inference_field: "I am a teapot" } + - match: { _source.inference_field.raw: "I am a test" } + - match: { _source.inference_field.inference.chunks.0.text: "I am a test" } + - match: { _source.another_inference_field.raw: "I am a teapot" } + - match: { _source.another_inference_field.inference.chunks.0.text: "I am a teapot" } - match: { _source.non_inference_field: "non inference test" } - - length: { _source._inference: 2 } - - match: { _source._inference.inference_field.chunks.0.text: "I am a test" } - - match: { _source._inference.another_inference_field.chunks.0.text: "I am a teapot" } - do: update: @@ -238,12 +229,11 @@ setup: index: test-sparse-index id: doc_1 - - match: { _source.inference_field: "updated inference test" } - - match: { _source.another_inference_field: "another updated inference test" } + - match: { _source.inference_field.raw: "updated inference test" } + - match: { _source.inference_field.inference.chunks.0.text: "updated inference test" } + - match: { _source.another_inference_field.raw: "another updated inference test" } + - match: { _source.another_inference_field.inference.chunks.0.text: "another updated inference test" } - match: { _source.non_inference_field: "non inference test" } - - length: { _source._inference: 2 } - - match: { _source._inference.inference_field.chunks.0.text: "updated inference test" } - - match: { _source._inference.another_inference_field.chunks.0.text: "another updated inference test" } - do: bulk: @@ -257,12 +247,11 @@ setup: index: test-sparse-index id: doc_1 - - match: { _source.inference_field: "bulk inference test" } - - match: { _source.another_inference_field: "bulk updated inference test" } + - match: { _source.inference_field.raw: "bulk inference test" } + - match: { _source.inference_field.inference.chunks.0.text: "bulk inference test" } + - match: { _source.another_inference_field.raw: "bulk updated inference test" } + - match: { _source.another_inference_field.inference.chunks.0.text: "bulk updated inference test" } - match: { _source.non_inference_field: "non inference test" } - - length: { _source._inference: 2 } - - match: { _source._inference.inference_field.chunks.0.text: "bulk inference test" } - - match: { _source._inference.another_inference_field.chunks.0.text: "bulk updated inference test" } --- "Reindex works for semantic_text fields": @@ -280,8 +269,8 @@ setup: index: test-sparse-index id: doc_1 - - set: { _source._inference.inference_field.chunks.0.inference: inference_field_embedding } - - set: { _source._inference.another_inference_field.chunks.0.inference: another_inference_field_embedding } + - set: { _source.inference_field.inference.chunks.0.embeddings: inference_field_embedding } + - set: { _source.another_inference_field.inference.chunks.0.embeddings: another_inference_field_embedding } - do: indices.refresh: { } @@ -314,17 +303,14 @@ setup: index: destination-index id: doc_1 - - match: { _source.inference_field: "inference test" } - - match: { _source.another_inference_field: "another inference test" } + - match: { _source.inference_field.raw: "inference test" } + - match: { _source.inference_field.inference.chunks.0.text: "inference test" } + - match: { _source.inference_field.inference.chunks.0.embeddings: $inference_field_embedding } + - match: { _source.another_inference_field.raw: "another inference test" } + - match: { _source.another_inference_field.inference.chunks.0.text: "another inference test" } + - match: { _source.another_inference_field.inference.chunks.0.embeddings: $another_inference_field_embedding } - match: { _source.non_inference_field: "non inference test" } - - length: { _source._inference: 2 } - - match: { _source._inference.inference_field.chunks.0.text: "inference test" } - - match: { _source._inference.another_inference_field.chunks.0.text: "another inference test" } - - - match: { _source._inference.inference_field.chunks.0.inference: $inference_field_embedding } - - match: { _source._inference.another_inference_field.chunks.0.inference: $another_inference_field_embedding } - --- "Fails for non-existent inference": - do: @@ -378,22 +364,6 @@ setup: - match: { items.0.update.status: 400 } - match: { items.0.update.error.reason: "Cannot apply update with a script on indices that contain [semantic_text] field(s)" } ---- -"Fails when providing inference results and there is no value for field": - - do: - catch: /The field \[inference_field\] is referenced in the \[_inference\] metadata field but has no value/ - index: - index: test-sparse-index - id: doc_1 - body: - _inference: - inference_field: - chunks: - - text: "inference test" - inference: - "hello": 0.123 - - --- "semantic_text copy_to calculate inference for source fields": - do: @@ -426,14 +396,13 @@ setup: index: test-copy-to-index id: doc_1 - - match: { _source.inference_field: "inference test" } - - length: { _source._inference.inference_field.chunks: 3 } - - exists: _source._inference.inference_field.chunks.0.inference - - exists: _source._inference.inference_field.chunks.0.text - - exists: _source._inference.inference_field.chunks.1.inference - - exists: _source._inference.inference_field.chunks.1.text - - exists: _source._inference.inference_field.chunks.2.inference - - exists: _source._inference.inference_field.chunks.2.text + - match: { _source.inference_field.raw: "inference test" } + - match: { _source.inference_field.inference.chunks.0.text: "another copy_to inference test" } + - exists: _source.inference_field.inference.chunks.0.embeddings + - match: { _source.inference_field.inference.chunks.1.text: "inference test" } + - exists: _source.inference_field.inference.chunks.1.embeddings + - match: { _source.inference_field.inference.chunks.2.text: "copy_to inference test" } + - exists: _source.inference_field.inference.chunks.2.embeddings --- diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml index 9dc109b3fb81d..27f233436b925 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml @@ -66,23 +66,3 @@ setup: id: doc_1 body: dense_field: "you know, for testing" - ---- -"Inference section contains unreferenced fields": - - do: - catch: /Field \[unknown_field\] is not registered as a \[semantic_text\] field type/ - index: - index: test-index - id: doc_1 - body: - non_inference_field: "you know, for testing" - _inference: - unknown_field: - inference_id: dense-inference-id - model_settings: - task_type: text_embedding - chunks: - - text: "inference test" - inference: [ 0.1, 0.2, 0.3, 0.4, 0.5 ] - - text: "another inference test" - inference: [ -0.1, -0.2, -0.3, -0.4, -0.5 ]