From d4e283dde6a7b4f93c1489ac7ff733100f864376 Mon Sep 17 00:00:00 2001
From: Jim Ferenczi
Date: Fri, 22 Mar 2024 17:31:47 +0000
Subject: [PATCH] [feature/semantic_text] Register semantic text sub fields in
 the mapping (#106560)

This PR refactors the semantic text field mapper to register its sub fields
in the mapping instead of re-creating them each time a document is parsed. It
also fixes the generation of these fields when the semantic text field is
defined inside an object field. Lastly, this change adds a new section called
model_settings to the field's parameters, which the field mapper updates when
inference results are received from a bulk action. The model settings become
available in the mapping as soon as the first document containing the
inference field is ingested, and they are used to validate subsequent
updates, ensuring consistency between what's used in the bulk action and
what's defined in the field mapping.
---
 .../xcontent/support/XContentMapValues.java   |   2 +-
 .../index/mapper/FieldMapper.java             |   8 +-
 .../elasticsearch/index/mapper/Mapping.java   |   2 +-
 .../vectors/SparseVectorFieldMapper.java      |   7 +-
 .../TestDenseInferenceServiceExtension.java   |   2 +-
 .../xpack/inference/InferencePlugin.java      |   4 +-
 .../ShardBulkInferenceActionFilter.java       |  16 +-
 .../mapper/InferenceMetadataFieldMapper.java  | 449 ++++++++++++++++++
 .../mapper/InferenceResultFieldMapper.java    | 372 ---------------
 .../mapper/SemanticTextFieldMapper.java       | 210 +++++++-
 .../mapper/SemanticTextModelSettings.java     | 136 ++++--
 .../SemanticTextClusterMetadataTests.java     |   4 +-
 .../ShardBulkInferenceActionFilterTests.java  |  12 +-
 ...=> InferenceMetadataFieldMapperTests.java} | 392 +++++++++------
 .../mapper/SemanticTextFieldMapperTests.java  | 235 +++++++--
 .../xpack/inference/model/TestModel.java      |  11 +
 .../inference/10_semantic_text_inference.yml  |  59 +--
 .../20_semantic_text_field_mapper.yml         |  97 +---
 18 files changed, 1267 insertions(+), 751 deletions(-)
 create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapper.java
 delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapper.java
 rename x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/{InferenceResultFieldMapperTests.java => InferenceMetadataFieldMapperTests.java} (57%)

diff --git a/server/src/main/java/org/elasticsearch/common/xcontent/support/XContentMapValues.java b/server/src/main/java/org/elasticsearch/common/xcontent/support/XContentMapValues.java
index 805931550ad62..f527b4cd8d684 100644
--- a/server/src/main/java/org/elasticsearch/common/xcontent/support/XContentMapValues.java
+++ b/server/src/main/java/org/elasticsearch/common/xcontent/support/XContentMapValues.java
@@ -555,7 +555,7 @@ public static Map nodeMapValue(Object node, String desc) {
         if (node instanceof Map) {
             return (Map) node;
         } else {
-            throw new ElasticsearchParseException(desc + " should be a hash but was of type: " + node.getClass());
+            throw new ElasticsearchParseException(desc + " should be a map but was of type: " + node.getClass());
         }
     }
 
diff --git a/server/src/main/java/org/elasticsearch/index/mapper/FieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/FieldMapper.java
index 71fd9edd49903..f9354025cab49 100644
--- a/server/src/main/java/org/elasticsearch/index/mapper/FieldMapper.java
+++ b/server/src/main/java/org/elasticsearch/index/mapper/FieldMapper.java
@@ -1176,7 +1176,7 @@ public static final class
Conflicts { private final String mapperName; private final List conflicts = new ArrayList<>(); - Conflicts(String mapperName) { + public Conflicts(String mapperName) { this.mapperName = mapperName; } @@ -1188,7 +1188,11 @@ void addConflict(String parameter, String existing, String toMerge) { conflicts.add("Cannot update parameter [" + parameter + "] from [" + existing + "] to [" + toMerge + "]"); } - void check() { + public boolean hasConflicts() { + return conflicts.isEmpty() == false; + } + + public void check() { if (conflicts.isEmpty()) { return; } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/Mapping.java b/server/src/main/java/org/elasticsearch/index/mapper/Mapping.java index 903e4e5da5b29..da184d6f7a45e 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/Mapping.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/Mapping.java @@ -76,7 +76,7 @@ public CompressedXContent toCompressedXContent() { /** * Returns the root object for the current mapping */ - RootObjectMapper getRoot() { + public RootObjectMapper getRoot() { return root; } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/SparseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/SparseVectorFieldMapper.java index 6532abed19044..58286d34dada1 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/SparseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/SparseVectorFieldMapper.java @@ -171,9 +171,12 @@ public void parse(DocumentParserContext context) throws IOException { } String feature = null; + boolean origIsWithLeafObject = context.path().isWithinLeafObject(); try { // make sure that we don't expand dots in field names while parsing - context.path().setWithinLeafObject(true); + if (context.path().isWithinLeafObject() == false) { + context.path().setWithinLeafObject(true); + } for (Token token = context.parser().nextToken(); token != Token.END_OBJECT; token = context.parser().nextToken()) { if (token == Token.FIELD_NAME) { feature = context.parser().currentName(); @@ -207,7 +210,7 @@ public void parse(DocumentParserContext context) throws IOException { context.addToFieldNames(fieldType().name()); } } finally { - context.path().setWithinLeafObject(false); + context.path().setWithinLeafObject(origIsWithLeafObject); } } diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java index 54fe6e01946b4..586850eb948d3 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java @@ -166,7 +166,7 @@ public static TestServiceSettings fromMap(Map map) { SimilarityMeasure similarity = null; String similarityStr = (String) map.remove("similarity"); if (similarityStr != null) { - similarity = SimilarityMeasure.valueOf(similarityStr); + similarity = SimilarityMeasure.fromString(similarityStr); } return new TestServiceSettings(model, dimensions, similarity); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 2a9c300e12c13..3fcd9049ae803 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -55,7 +55,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.RequestExecutorServiceSettings; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; -import org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapper; +import org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper; import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; import org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.elasticsearch.xpack.inference.rest.RestDeleteInferenceModelAction; @@ -284,7 +284,7 @@ public Map getMappers() { @Override public Map getMetadataMappers() { - return Map.of(InferenceResultFieldMapper.NAME, InferenceResultFieldMapper.PARSER); + return Map.of(InferenceMetadataFieldMapper.NAME, InferenceMetadataFieldMapper.PARSER); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java index fbf84762eb314..00dc195313a61 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -38,7 +38,7 @@ import org.elasticsearch.inference.Model; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; -import org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapper; +import org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper; import org.elasticsearch.xpack.inference.registry.ModelRegistry; import java.util.ArrayList; @@ -50,7 +50,7 @@ /** * An {@link ActionFilter} that performs inference on {@link BulkShardRequest} asynchronously and stores the results in - * the individual {@link BulkItemRequest}. The results are then consumed by the {@link InferenceResultFieldMapper} + * the individual {@link BulkItemRequest}. The results are then consumed by the {@link InferenceMetadataFieldMapper} * in the subsequent {@link TransportShardBulkAction} downstream. 
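+ *
+ * As an illustration (the field name is hypothetical), a bulk item whose source is
+ * {@code {"my_field": "some text"}}, where {@code my_field} is a {@code semantic_text} field, leaves this filter
+ * with its source augmented as {@code {"my_field": "some text", "_inference": {"my_field": {"inference_id": ...,
+ * "model_settings": {...}, "chunks": [...]}}}} (see
+ * {@link InferenceMetadataFieldMapper#applyFieldInference} for the exact keys written).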
 */
 public class ShardBulkInferenceActionFilter implements MappedActionFilter {
@@ -267,10 +267,10 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons
         Map newDocMap = indexRequest.sourceAsMap();
         Map inferenceMap = new LinkedHashMap<>();
         // ignore the existing inference map if any
-        newDocMap.put(InferenceResultFieldMapper.NAME, inferenceMap);
+        newDocMap.put(InferenceMetadataFieldMapper.NAME, inferenceMap);
         for (FieldInferenceResponse fieldResponse : response.responses()) {
             try {
-                InferenceResultFieldMapper.applyFieldInference(
+                InferenceMetadataFieldMapper.applyFieldInference(
                     inferenceMap,
                     fieldResponse.field(),
                     fieldResponse.model(),
@@ -295,6 +295,7 @@ private Map> createFieldInferenceRequests(Bu
                 continue;
             }
             final Map docMap = indexRequest.sourceAsMap();
+            boolean hasInput = false;
             for (var entry : fieldInferenceMetadata.getFieldInferenceOptions().entrySet()) {
                 String field = entry.getKey();
                 String inferenceId = entry.getValue().inferenceId();
@@ -315,6 +316,7 @@ private Map> createFieldInferenceRequests(Bu
                 if (value instanceof String valueStr) {
                     List fieldRequests = fieldRequestsMap.computeIfAbsent(inferenceId, k -> new ArrayList<>());
                     fieldRequests.add(new FieldInferenceRequest(item.id(), field, valueStr));
+                    hasInput = true;
                 } else {
                     inferenceResults.get(item.id()).failures.add(
                         new ElasticsearchStatusException(
@@ -326,6 +328,12 @@ private Map> createFieldInferenceRequests(Bu
                     );
                 }
             }
+            if (hasInput == false) {
+                // remove the existing _inference field (if present) since none of the content requires inference.
+                if (docMap.remove(InferenceMetadataFieldMapper.NAME) != null) {
+                    indexRequest.source(docMap);
+                }
+            }
         }
         return fieldRequestsMap;
     }
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapper.java
new file mode 100644
index 0000000000000..9eeb7a5407bc4
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapper.java
@@ -0,0 +1,449 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */ + +package org.elasticsearch.xpack.inference.mapper; + +import org.apache.lucene.search.Query; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.xcontent.support.XContentMapValues; +import org.elasticsearch.index.mapper.DocumentParserContext; +import org.elasticsearch.index.mapper.DocumentParsingException; +import org.elasticsearch.index.mapper.FieldMapper; +import org.elasticsearch.index.mapper.MappedFieldType; +import org.elasticsearch.index.mapper.Mapper; +import org.elasticsearch.index.mapper.MapperBuilderContext; +import org.elasticsearch.index.mapper.MetadataFieldMapper; +import org.elasticsearch.index.mapper.NestedObjectMapper; +import org.elasticsearch.index.mapper.ObjectMapper; +import org.elasticsearch.index.mapper.SourceLoader; +import org.elasticsearch.index.mapper.SourceValueFetcher; +import org.elasticsearch.index.mapper.TextSearchInfo; +import org.elasticsearch.index.mapper.ValueFetcher; +import org.elasticsearch.index.query.SearchExecutionContext; +import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.logging.LogManager; +import org.elasticsearch.logging.Logger; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xcontent.DeprecationHandler; +import org.elasticsearch.xcontent.NamedXContentRegistry; +import org.elasticsearch.xcontent.XContentLocation; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xcontent.support.MapXContentParser; +import org.elasticsearch.xpack.core.inference.results.ChunkedSparseEmbeddingResults; +import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingResults; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * A mapper for the {@code _inference} field. + *
+ *
+ * This mapper works in tandem with {@link SemanticTextFieldMapper semantic_text} fields to index inference results. + * The inference results for {@code semantic_text} fields are written to {@code _source} by an upstream process like so: + *
+ *
+ *
+ * {
+ *     "_source": {
+ *         "my_semantic_text_field": "these are not the droids you're looking for",
+ *         "_inference": {
+ *             "my_semantic_text_field": {
+ *                  "inference_id": "my_inference_id",
+ *                  "model_settings": {
+ *                      "task_type": "SPARSE_EMBEDDING"
+ *                  },
+ *                  "chunks": [
+ *                      {
+ *                          "inference": {
+ *                              "lucas": 0.05212344,
+ *                              "ty": 0.041213956,
+ *                              "dragon": 0.50991,
+ *                              "type": 0.23241979,
+ *                              "dr": 1.9312073,
+ *                              "##o": 0.2797593
+ *                          },
+ *                          "text": "these are not the droids you're looking for"
+ *                      }
+ *                  ]
+ *              }
+ *          }
+ *      }
+ * }
+ * 
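+ *
+ * Each element of the {@code chunks} array pairs the inference results for one chunk ({@code inference}) with the
+ * chunk's original input text ({@code text}); both subfields are required when the {@code _inference} field is
+ * parsed.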
+ * + * This mapper parses the contents of the {@code _inference} field and indexes it as if the mapping were configured like so: + *
+ *
+ *
+ * {
+ *     "mappings": {
+ *         "properties": {
+ *             "my_semantic_text_field": {
+ *                 "chunks": {
+ *                      "type": "nested",
+ *                      "properties": {
+ *                          "inference": {
+ *                              "type": "sparse_vector|dense_vector"
+ *                          },
+ *                          "text": {
+ *                              "type": "keyword",
+ *                              "index": false,
+ *                              "doc_values": false
+ *                          }
+ *                     }
+ *                 }
+ *             }
+ *         }
+ *     }
+ * }
+ * 
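+ *
+ * In addition, once the first document carrying inference results is ingested, the reported {@code model_settings}
+ * are recorded on the corresponding {@code semantic_text} field and later documents are validated against them. As a
+ * sketch (illustrative values), the updated field mapping would look similar to:
+ *
+ * {
+ *     "my_semantic_text_field": {
+ *         "type": "semantic_text",
+ *         "inference_id": "my_inference_id",
+ *         "model_settings": {
+ *             "task_type": "SPARSE_EMBEDDING"
+ *         }
+ *     }
+ * }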
+ */ +public class InferenceMetadataFieldMapper extends MetadataFieldMapper { + public static final String NAME = "_inference"; + public static final String CONTENT_TYPE = "_inference"; + + public static final String INFERENCE_ID = "inference_id"; + public static final String CHUNKS = "chunks"; + public static final String INFERENCE_CHUNKS_RESULTS = "inference"; + public static final String INFERENCE_CHUNKS_TEXT = "text"; + + public static final TypeParser PARSER = new FixedTypeParser(c -> new InferenceMetadataFieldMapper()); + + private static final Logger logger = LogManager.getLogger(InferenceMetadataFieldMapper.class); + + private static final Set REQUIRED_SUBFIELDS = Set.of(INFERENCE_CHUNKS_TEXT, INFERENCE_CHUNKS_RESULTS); + + static class SemanticTextInferenceFieldType extends MappedFieldType { + private static final MappedFieldType INSTANCE = new SemanticTextInferenceFieldType(); + + SemanticTextInferenceFieldType() { + super(NAME, true, false, false, TextSearchInfo.NONE, Collections.emptyMap()); + } + + @Override + public String typeName() { + return CONTENT_TYPE; + } + + @Override + public ValueFetcher valueFetcher(SearchExecutionContext context, String format) { + return SourceValueFetcher.identity(name(), context, format); + } + + @Override + public Query termQuery(Object value, SearchExecutionContext context) { + return null; + } + } + + public InferenceMetadataFieldMapper() { + super(SemanticTextInferenceFieldType.INSTANCE); + } + + @Override + protected void parseCreateField(DocumentParserContext context) throws IOException { + XContentParser parser = context.parser(); + failIfTokenIsNot(parser.getTokenLocation(), parser, XContentParser.Token.START_OBJECT); + boolean origWithLeafObject = context.path().isWithinLeafObject(); + try { + // make sure that we don't expand dots in field names while parsing + context.path().setWithinLeafObject(true); + for (XContentParser.Token token = parser.nextToken(); token != XContentParser.Token.END_OBJECT; token = parser.nextToken()) { + failIfTokenIsNot(parser.getTokenLocation(), parser, XContentParser.Token.FIELD_NAME); + parseSingleField(context); + } + } finally { + context.path().setWithinLeafObject(origWithLeafObject); + } + } + + private NestedObjectMapper updateSemanticTextFieldMapper( + DocumentParserContext docContext, + SemanticTextMapperContext semanticFieldContext, + String newInferenceId, + SemanticTextModelSettings newModelSettings, + XContentLocation xContentLocation + ) { + final String fullFieldName = semanticFieldContext.mapper.fieldType().name(); + final String inferenceId = semanticFieldContext.mapper.fieldType().getInferenceId(); + if (newInferenceId.equals(inferenceId) == false) { + throw new DocumentParsingException( + xContentLocation, + Strings.format( + "The configured %s [%s] for field [%s] doesn't match the %s [%s] reported in the document.", + INFERENCE_ID, + inferenceId, + fullFieldName, + INFERENCE_ID, + newInferenceId + ) + ); + } + if (newModelSettings.taskType() == TaskType.TEXT_EMBEDDING && newModelSettings.dimensions() == null) { + throw new DocumentParsingException( + xContentLocation, + "Model settings for field [" + fullFieldName + "] must contain dimensions" + ); + } + if (semanticFieldContext.mapper.getModelSettings() == null) { + SemanticTextFieldMapper newMapper = new SemanticTextFieldMapper.Builder( + semanticFieldContext.mapper.simpleName(), + docContext.indexSettings().getIndexVersionCreated() + 
).setInferenceId(newInferenceId).setModelSettings(newModelSettings).build(semanticFieldContext.context); + docContext.addDynamicMapper(newMapper); + return newMapper.getSubMappers(); + } else { + SemanticTextFieldMapper.Conflicts conflicts = new Conflicts(fullFieldName); + SemanticTextFieldMapper.canMergeModelSettings(semanticFieldContext.mapper.getModelSettings(), newModelSettings, conflicts); + try { + conflicts.check(); + } catch (Exception exc) { + throw new DocumentParsingException(xContentLocation, "Incompatible model_settings", exc); + } + } + return semanticFieldContext.mapper.getSubMappers(); + } + + private void parseSingleField(DocumentParserContext context) throws IOException { + XContentParser parser = context.parser(); + String fieldName = parser.currentName(); + SemanticTextMapperContext builderContext = createSemanticFieldContext(context, fieldName); + if (builderContext == null) { + throw new DocumentParsingException( + parser.getTokenLocation(), + Strings.format("Field [%s] is not registered as a [%s] field type", fieldName, SemanticTextFieldMapper.CONTENT_TYPE) + ); + } + parser.nextToken(); + failIfTokenIsNot(parser.getTokenLocation(), parser, XContentParser.Token.START_OBJECT); + + // record the location of the inference field in the original source + XContentLocation xContentLocation = parser.getTokenLocation(); + // parse eagerly to extract the inference id and the model settings first + Map map = parser.mapOrdered(); + + // inference_id + Object inferenceIdObj = map.remove(INFERENCE_ID); + final String inferenceId = XContentMapValues.nodeStringValue(inferenceIdObj, null); + if (inferenceId == null) { + throw new IllegalArgumentException("required [" + INFERENCE_ID + "] is missing"); + } + + // model_settings + Object modelSettingsObj = map.remove(SemanticTextModelSettings.NAME); + if (modelSettingsObj == null) { + throw new DocumentParsingException( + parser.getTokenLocation(), + Strings.format( + "Missing required [%s] for field [%s] of type [%s]", + SemanticTextModelSettings.NAME, + fieldName, + SemanticTextFieldMapper.CONTENT_TYPE + ) + ); + } + final SemanticTextModelSettings modelSettings; + try { + modelSettings = SemanticTextModelSettings.fromMap(modelSettingsObj); + } catch (Exception exc) { + throw new DocumentParsingException( + xContentLocation, + Strings.format( + "Error parsing [%s] for field [%s] of type [%s]", + SemanticTextModelSettings.NAME, + fieldName, + SemanticTextFieldMapper.CONTENT_TYPE + ), + exc + ); + } + + var nestedObjectMapper = updateSemanticTextFieldMapper(context, builderContext, inferenceId, modelSettings, xContentLocation); + + // we know the model settings, so we can (re) parse the results array now + XContentParser subParser = new MapXContentParser( + NamedXContentRegistry.EMPTY, + DeprecationHandler.IGNORE_DEPRECATIONS, + map, + XContentType.JSON + ); + DocumentParserContext mapContext = context.switchParser(subParser); + parseFieldInference(xContentLocation, subParser, mapContext, nestedObjectMapper); + } + + private void parseFieldInference( + XContentLocation xContentLocation, + XContentParser parser, + DocumentParserContext context, + NestedObjectMapper nestedMapper + ) throws IOException { + parser.nextToken(); + failIfTokenIsNot(xContentLocation, parser, XContentParser.Token.START_OBJECT); + for (XContentParser.Token token = parser.nextToken(); token != XContentParser.Token.END_OBJECT; token = parser.nextToken()) { + switch (parser.currentName()) { + case CHUNKS -> parseChunks(xContentLocation, parser, context, 
nestedMapper); + default -> throw new DocumentParsingException(xContentLocation, "Unknown field name " + parser.currentName()); + } + } + } + + private void parseChunks( + XContentLocation xContentLocation, + XContentParser parser, + DocumentParserContext context, + NestedObjectMapper nestedMapper + ) throws IOException { + parser.nextToken(); + failIfTokenIsNot(xContentLocation, parser, XContentParser.Token.START_ARRAY); + for (XContentParser.Token token = parser.nextToken(); token != XContentParser.Token.END_ARRAY; token = parser.nextToken()) { + DocumentParserContext subContext = context.createNestedContext(nestedMapper); + parseResultsObject(xContentLocation, parser, subContext, nestedMapper); + } + } + + private void parseResultsObject( + XContentLocation xContentLocation, + XContentParser parser, + DocumentParserContext context, + NestedObjectMapper nestedMapper + ) throws IOException { + failIfTokenIsNot(xContentLocation, parser, XContentParser.Token.START_OBJECT); + Set visited = new HashSet<>(); + for (XContentParser.Token token = parser.nextToken(); token != XContentParser.Token.END_OBJECT; token = parser.nextToken()) { + failIfTokenIsNot(xContentLocation, parser, XContentParser.Token.FIELD_NAME); + visited.add(parser.currentName()); + FieldMapper fieldMapper = (FieldMapper) nestedMapper.getMapper(parser.currentName()); + if (fieldMapper == null) { + if (REQUIRED_SUBFIELDS.contains(parser.currentName())) { + throw new DocumentParsingException( + xContentLocation, + "Missing sub-fields definition for [" + parser.currentName() + "]" + ); + } else { + logger.debug("Skipping indexing of unrecognized field name [" + parser.currentName() + "]"); + advancePastCurrentFieldName(xContentLocation, parser); + continue; + } + } + parser.nextToken(); + fieldMapper.parse(context); + } + if (visited.containsAll(REQUIRED_SUBFIELDS) == false) { + Set missingSubfields = REQUIRED_SUBFIELDS.stream() + .filter(s -> visited.contains(s) == false) + .collect(Collectors.toSet()); + throw new DocumentParsingException(xContentLocation, "Missing required subfields: " + missingSubfields); + } + } + + private static void failIfTokenIsNot(XContentLocation xContentLocation, XContentParser parser, XContentParser.Token expected) { + if (parser.currentToken() != expected) { + throw new DocumentParsingException(xContentLocation, "Expected a " + expected.toString() + ", got " + parser.currentToken()); + } + } + + private static void advancePastCurrentFieldName(XContentLocation xContentLocation, XContentParser parser) throws IOException { + assert parser.currentToken() == XContentParser.Token.FIELD_NAME; + XContentParser.Token token = parser.nextToken(); + if (token == XContentParser.Token.START_OBJECT || token == XContentParser.Token.START_ARRAY) { + parser.skipChildren(); + } else if (token.isValue() == false && token != XContentParser.Token.VALUE_NULL) { + throw new DocumentParsingException(xContentLocation, "Expected a START_* or VALUE_*, got " + token); + } + } + + @Override + protected String contentType() { + return CONTENT_TYPE; + } + + @Override + public SourceLoader.SyntheticFieldLoader syntheticFieldLoader() { + return SourceLoader.SyntheticFieldLoader.NOTHING; + } + + public static void applyFieldInference( + Map inferenceMap, + String field, + Model model, + ChunkedInferenceServiceResults results + ) throws ElasticsearchException { + List> chunks = new ArrayList<>(); + if (results instanceof ChunkedSparseEmbeddingResults textExpansionResults) { + for (var chunk : textExpansionResults.getChunkedResults()) 
{ + chunks.add(chunk.asMap()); + } + } else if (results instanceof ChunkedTextEmbeddingResults textEmbeddingResults) { + for (var chunk : textEmbeddingResults.getChunks()) { + chunks.add(chunk.asMap()); + } + } else { + throw new ElasticsearchStatusException( + "Invalid inference results format for field [{}] with inference id [{}], got {}", + RestStatus.BAD_REQUEST, + field, + model.getInferenceEntityId(), + results.getWriteableName() + ); + } + Map fieldMap = new LinkedHashMap<>(); + fieldMap.put(INFERENCE_ID, model.getInferenceEntityId()); + fieldMap.putAll(new SemanticTextModelSettings(model).asMap()); + fieldMap.put(CHUNKS, chunks); + inferenceMap.put(field, fieldMap); + } + + record SemanticTextMapperContext(MapperBuilderContext context, SemanticTextFieldMapper mapper) {} + + /** + * Returns the {@link SemanticTextFieldMapper} associated with the provided {@code fullName} + * and the {@link MapperBuilderContext} that was used to build it. + * If the field is not found or is of the wrong type, this method returns {@code null}. + */ + static SemanticTextMapperContext createSemanticFieldContext(DocumentParserContext docContext, String fullName) { + ObjectMapper rootMapper = docContext.mappingLookup().getMapping().getRoot(); + return createSemanticFieldContext(MapperBuilderContext.root(false, false), rootMapper, fullName.split("\\.")); + } + + static SemanticTextMapperContext createSemanticFieldContext( + MapperBuilderContext mapperContext, + ObjectMapper objectMapper, + String[] paths + ) { + Mapper mapper = objectMapper.getMapper(paths[0]); + if (mapper instanceof ObjectMapper newObjectMapper) { + mapperContext = mapperContext.createChildContext(paths[0], ObjectMapper.Dynamic.FALSE); + return createSemanticFieldContext(mapperContext, newObjectMapper, Arrays.copyOfRange(paths, 1, paths.length)); + } else if (mapper instanceof SemanticTextFieldMapper semanticMapper) { + return new SemanticTextMapperContext(mapperContext, semanticMapper); + } else { + if (mapper == null || paths.length == 1) { + return null; + } + // check if the semantic field is defined within a multi-field + Mapper fieldMapper = objectMapper.getMapper(String.join(".", Arrays.asList(paths))); + if (fieldMapper instanceof SemanticTextFieldMapper semanticMapper) { + return new SemanticTextMapperContext(mapperContext, semanticMapper); + } + } + return null; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapper.java deleted file mode 100644 index 2ede5419ab74e..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapper.java +++ /dev/null @@ -1,372 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. 
- */ - -package org.elasticsearch.xpack.inference.mapper; - -import org.apache.lucene.search.Query; -import org.elasticsearch.ElasticsearchException; -import org.elasticsearch.ElasticsearchStatusException; -import org.elasticsearch.common.Strings; -import org.elasticsearch.index.IndexVersion; -import org.elasticsearch.index.mapper.DocumentParserContext; -import org.elasticsearch.index.mapper.DocumentParsingException; -import org.elasticsearch.index.mapper.FieldMapper; -import org.elasticsearch.index.mapper.MappedFieldType; -import org.elasticsearch.index.mapper.Mapper; -import org.elasticsearch.index.mapper.MapperBuilderContext; -import org.elasticsearch.index.mapper.MetadataFieldMapper; -import org.elasticsearch.index.mapper.NestedObjectMapper; -import org.elasticsearch.index.mapper.ObjectMapper; -import org.elasticsearch.index.mapper.SourceLoader; -import org.elasticsearch.index.mapper.SourceValueFetcher; -import org.elasticsearch.index.mapper.TextFieldMapper; -import org.elasticsearch.index.mapper.TextSearchInfo; -import org.elasticsearch.index.mapper.ValueFetcher; -import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; -import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper; -import org.elasticsearch.index.query.SearchExecutionContext; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; -import org.elasticsearch.inference.Model; -import org.elasticsearch.inference.SimilarityMeasure; -import org.elasticsearch.inference.TaskType; -import org.elasticsearch.logging.LogManager; -import org.elasticsearch.logging.Logger; -import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.xcontent.XContentParser; -import org.elasticsearch.xpack.core.inference.results.ChunkedSparseEmbeddingResults; -import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingResults; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashSet; -import java.util.LinkedHashMap; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.stream.Collectors; - -/** - * A mapper for the {@code _semantic_text_inference} field. - *
- *
- * This mapper works in tandem with {@link SemanticTextFieldMapper semantic_text} fields to index inference results. - * The inference results for {@code semantic_text} fields are written to {@code _source} by an upstream process like so: - *
- *
- *
- * {
- *     "_source": {
- *         "my_semantic_text_field": "these are not the droids you're looking for",
- *         "_inference": {
- *             "my_semantic_text_field": [
- *                 {
- *                     "sparse_embedding": {
- *                          "lucas": 0.05212344,
- *                          "ty": 0.041213956,
- *                          "dragon": 0.50991,
- *                          "type": 0.23241979,
- *                          "dr": 1.9312073,
- *                          "##o": 0.2797593
- *                     },
- *                     "text": "these are not the droids you're looking for"
- *                 }
- *             ]
- *         }
- *     }
- * }
- * 
- * - * This mapper parses the contents of the {@code _semantic_text_inference} field and indexes it as if the mapping were configured like so: - *
- *
- *
- * {
- *     "mappings": {
- *         "properties": {
- *             "my_semantic_text_field": {
- *                 "type": "nested",
- *                 "properties": {
- *                     "sparse_embedding": {
- *                         "type": "sparse_vector"
- *                     },
- *                     "text": {
- *                         "type": "text",
- *                         "index": false
- *                     }
- *                 }
- *             }
- *         }
- *     }
- * }
- * 
- */ -public class InferenceResultFieldMapper extends MetadataFieldMapper { - public static final String NAME = "_inference"; - public static final String CONTENT_TYPE = "_inference"; - - public static final String RESULTS = "results"; - public static final String INFERENCE_CHUNKS_RESULTS = "inference"; - public static final String INFERENCE_CHUNKS_TEXT = "text"; - - public static final TypeParser PARSER = new FixedTypeParser(c -> new InferenceResultFieldMapper()); - - private static final Logger logger = LogManager.getLogger(InferenceResultFieldMapper.class); - - private static final Set REQUIRED_SUBFIELDS = Set.of(INFERENCE_CHUNKS_TEXT, INFERENCE_CHUNKS_RESULTS); - - static class SemanticTextInferenceFieldType extends MappedFieldType { - private static final MappedFieldType INSTANCE = new SemanticTextInferenceFieldType(); - - SemanticTextInferenceFieldType() { - super(NAME, true, false, false, TextSearchInfo.NONE, Collections.emptyMap()); - } - - @Override - public String typeName() { - return CONTENT_TYPE; - } - - @Override - public ValueFetcher valueFetcher(SearchExecutionContext context, String format) { - return SourceValueFetcher.identity(name(), context, format); - } - - @Override - public Query termQuery(Object value, SearchExecutionContext context) { - return null; - } - } - - public InferenceResultFieldMapper() { - super(SemanticTextInferenceFieldType.INSTANCE); - } - - @Override - protected void parseCreateField(DocumentParserContext context) throws IOException { - XContentParser parser = context.parser(); - failIfTokenIsNot(parser, XContentParser.Token.START_OBJECT); - - parseAllFields(context); - } - - private static void parseAllFields(DocumentParserContext context) throws IOException { - XContentParser parser = context.parser(); - MapperBuilderContext mapperBuilderContext = MapperBuilderContext.root(false, false); - for (XContentParser.Token token = parser.nextToken(); token != XContentParser.Token.END_OBJECT; token = parser.nextToken()) { - failIfTokenIsNot(parser, XContentParser.Token.FIELD_NAME); - - parseSingleField(context, mapperBuilderContext); - } - } - - private static void parseSingleField(DocumentParserContext context, MapperBuilderContext mapperBuilderContext) throws IOException { - - XContentParser parser = context.parser(); - String fieldName = parser.currentName(); - Mapper mapper = context.getMapper(fieldName); - if (mapper == null || SemanticTextFieldMapper.CONTENT_TYPE.equals(mapper.typeName()) == false) { - throw new DocumentParsingException( - parser.getTokenLocation(), - Strings.format("Field [%s] is not registered as a %s field type", fieldName, SemanticTextFieldMapper.CONTENT_TYPE) - ); - } - parser.nextToken(); - failIfTokenIsNot(parser, XContentParser.Token.START_OBJECT); - parser.nextToken(); - SemanticTextModelSettings modelSettings = SemanticTextModelSettings.parse(parser); - for (XContentParser.Token token = parser.nextToken(); token != XContentParser.Token.END_OBJECT; token = parser.nextToken()) { - failIfTokenIsNot(parser, XContentParser.Token.FIELD_NAME); - - String currentName = parser.currentName(); - if (RESULTS.equals(currentName)) { - NestedObjectMapper nestedObjectMapper = createInferenceResultsObjectMapper( - context, - mapperBuilderContext, - fieldName, - modelSettings - ); - parseFieldInferenceChunks(context, mapperBuilderContext, fieldName, modelSettings, nestedObjectMapper); - } else { - logger.debug("Skipping unrecognized field name [" + currentName + "]"); - advancePastCurrentFieldName(parser); - } - } - } - - private static void 
parseFieldInferenceChunks( - DocumentParserContext context, - MapperBuilderContext mapperBuilderContext, - String fieldName, - SemanticTextModelSettings modelSettings, - NestedObjectMapper nestedObjectMapper - ) throws IOException { - XContentParser parser = context.parser(); - - parser.nextToken(); - failIfTokenIsNot(parser, XContentParser.Token.START_ARRAY); - - for (XContentParser.Token token = parser.nextToken(); token != XContentParser.Token.END_ARRAY; token = parser.nextToken()) { - DocumentParserContext nestedContext = context.createNestedContext(nestedObjectMapper); - parseFieldInferenceChunkElement(nestedContext, nestedObjectMapper, modelSettings); - } - } - - private static void parseFieldInferenceChunkElement( - DocumentParserContext context, - ObjectMapper objectMapper, - SemanticTextModelSettings modelSettings - ) throws IOException { - XContentParser parser = context.parser(); - DocumentParserContext childContext = context.createChildContext(objectMapper); - - failIfTokenIsNot(parser, XContentParser.Token.START_OBJECT); - - Set visitedSubfields = new HashSet<>(); - for (XContentParser.Token token = parser.nextToken(); token != XContentParser.Token.END_OBJECT; token = parser.nextToken()) { - failIfTokenIsNot(parser, XContentParser.Token.FIELD_NAME); - - String currentName = parser.currentName(); - visitedSubfields.add(currentName); - - Mapper childMapper = objectMapper.getMapper(currentName); - if (childMapper == null) { - logger.debug("Skipping indexing of unrecognized field name [" + currentName + "]"); - advancePastCurrentFieldName(parser); - continue; - } - - if (childMapper instanceof FieldMapper fieldMapper) { - parser.nextToken(); - fieldMapper.parse(childContext); - } else { - // This should never happen, but fail parsing if it does so that it's not a silent failure - throw new DocumentParsingException( - parser.getTokenLocation(), - Strings.format("Unhandled mapper type [%s] for field [%s]", childMapper.getClass(), currentName) - ); - } - } - - if (visitedSubfields.containsAll(REQUIRED_SUBFIELDS) == false) { - Set missingSubfields = REQUIRED_SUBFIELDS.stream() - .filter(s -> visitedSubfields.contains(s) == false) - .collect(Collectors.toSet()); - throw new DocumentParsingException(parser.getTokenLocation(), "Missing required subfields: " + missingSubfields); - } - } - - private static NestedObjectMapper createInferenceResultsObjectMapper( - DocumentParserContext context, - MapperBuilderContext mapperBuilderContext, - String fieldName, - SemanticTextModelSettings modelSettings - ) { - IndexVersion indexVersionCreated = context.indexSettings().getIndexVersionCreated(); - FieldMapper.Builder resultsBuilder; - if (modelSettings.taskType() == TaskType.SPARSE_EMBEDDING) { - resultsBuilder = new SparseVectorFieldMapper.Builder(INFERENCE_CHUNKS_RESULTS); - } else if (modelSettings.taskType() == TaskType.TEXT_EMBEDDING) { - DenseVectorFieldMapper.Builder denseVectorMapperBuilder = new DenseVectorFieldMapper.Builder( - INFERENCE_CHUNKS_RESULTS, - indexVersionCreated - ); - SimilarityMeasure similarity = modelSettings.similarity(); - if (similarity != null) { - switch (similarity) { - case COSINE -> denseVectorMapperBuilder.similarity(DenseVectorFieldMapper.VectorSimilarity.COSINE); - case DOT_PRODUCT -> denseVectorMapperBuilder.similarity(DenseVectorFieldMapper.VectorSimilarity.DOT_PRODUCT); - default -> throw new IllegalArgumentException( - "Unknown similarity measure for field [" + fieldName + "] in model settings: " + similarity - ); - } - } - Integer dimensions = 
modelSettings.dimensions(); - if (dimensions == null) { - throw new IllegalArgumentException("Model settings for field [" + fieldName + "] must contain dimensions"); - } - denseVectorMapperBuilder.dimensions(dimensions); - resultsBuilder = denseVectorMapperBuilder; - } else { - throw new IllegalArgumentException("Unknown task type for field [" + fieldName + "]: " + modelSettings.taskType()); - } - - TextFieldMapper.Builder textMapperBuilder = new TextFieldMapper.Builder( - INFERENCE_CHUNKS_TEXT, - indexVersionCreated, - context.indexAnalyzers() - ).index(false).store(false); - - NestedObjectMapper.Builder nestedBuilder = new NestedObjectMapper.Builder( - fieldName, - context.indexSettings().getIndexVersionCreated() - ); - nestedBuilder.add(resultsBuilder).add(textMapperBuilder); - - return nestedBuilder.build(mapperBuilderContext); - } - - private static void advancePastCurrentFieldName(XContentParser parser) throws IOException { - assert parser.currentToken() == XContentParser.Token.FIELD_NAME; - - XContentParser.Token token = parser.nextToken(); - if (token == XContentParser.Token.START_OBJECT || token == XContentParser.Token.START_ARRAY) { - parser.skipChildren(); - } else if (token.isValue() == false && token != XContentParser.Token.VALUE_NULL) { - throw new DocumentParsingException(parser.getTokenLocation(), "Expected a START_* or VALUE_*, got " + token); - } - } - - private static void failIfTokenIsNot(XContentParser parser, XContentParser.Token expected) { - if (parser.currentToken() != expected) { - throw new DocumentParsingException( - parser.getTokenLocation(), - "Expected a " + expected.toString() + ", got " + parser.currentToken() - ); - } - } - - @Override - protected String contentType() { - return CONTENT_TYPE; - } - - @Override - public SourceLoader.SyntheticFieldLoader syntheticFieldLoader() { - return SourceLoader.SyntheticFieldLoader.NOTHING; - } - - public static void applyFieldInference( - Map inferenceMap, - String field, - Model model, - ChunkedInferenceServiceResults results - ) throws ElasticsearchException { - List> chunks = new ArrayList<>(); - if (results instanceof ChunkedSparseEmbeddingResults textExpansionResults) { - for (var chunk : textExpansionResults.getChunkedResults()) { - chunks.add(chunk.asMap()); - } - } else if (results instanceof ChunkedTextEmbeddingResults textEmbeddingResults) { - for (var chunk : textEmbeddingResults.getChunks()) { - chunks.add(chunk.asMap()); - } - } else { - throw new ElasticsearchStatusException( - "Invalid inference results format for field [{}] with inference id [{}], got {}", - RestStatus.BAD_REQUEST, - field, - model.getInferenceEntityId(), - results.getWriteableName() - ); - } - Map fieldMap = new LinkedHashMap<>(); - fieldMap.putAll(new SemanticTextModelSettings(model).asMap()); - fieldMap.put(InferenceResultFieldMapper.RESULTS, chunks); - inferenceMap.put(field, fieldMap); - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java index 83272a10f98d4..2445d5c8751a5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java @@ -9,30 +9,50 @@ import org.apache.lucene.search.Query; import org.elasticsearch.common.Strings; +import org.elasticsearch.index.IndexVersion; 
 import org.elasticsearch.index.fielddata.FieldDataContext;
 import org.elasticsearch.index.fielddata.IndexFieldData;
 import org.elasticsearch.index.mapper.DocumentParserContext;
 import org.elasticsearch.index.mapper.FieldMapper;
 import org.elasticsearch.index.mapper.InferenceModelFieldType;
+import org.elasticsearch.index.mapper.KeywordFieldMapper;
 import org.elasticsearch.index.mapper.MappedFieldType;
+import org.elasticsearch.index.mapper.Mapper;
 import org.elasticsearch.index.mapper.MapperBuilderContext;
+import org.elasticsearch.index.mapper.NestedObjectMapper;
+import org.elasticsearch.index.mapper.ObjectMapper;
 import org.elasticsearch.index.mapper.SimpleMappedFieldType;
+import org.elasticsearch.index.mapper.SourceLoader;
 import org.elasticsearch.index.mapper.SourceValueFetcher;
 import org.elasticsearch.index.mapper.TextSearchInfo;
 import org.elasticsearch.index.mapper.ValueFetcher;
+import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
+import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper;
 import org.elasticsearch.index.query.SearchExecutionContext;
+import org.elasticsearch.inference.SimilarityMeasure;
+import org.elasticsearch.logging.LogManager;
+import org.elasticsearch.logging.Logger;
+import org.elasticsearch.xcontent.XContentBuilder;
 
 import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
 import java.util.Map;
+import java.util.Objects;
+
+import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper.CHUNKS;
+import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper.INFERENCE_CHUNKS_RESULTS;
+import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper.INFERENCE_CHUNKS_TEXT;
 
 /**
- * A {@link FieldMapper} for semantic text fields. These fields have a model id reference, that is used for performing inference
- * at ingestion and query time.
- * For now, it is compatible with text expansion models only, but will be extended to support dense vector models as well.
+ * A {@link FieldMapper} for semantic text fields.
+ * These fields have an inference id reference, which is used for performing inference at ingestion and query time.
  * This field mapper performs no indexing, as inference results will be included as a different field in the document source, and will
- * be indexed using {@link InferenceResultFieldMapper}.
+ * be indexed using {@link InferenceMetadataFieldMapper}.
*/ public class SemanticTextFieldMapper extends FieldMapper { + private static final Logger logger = LogManager.getLogger(SemanticTextFieldMapper.class); public static final String CONTENT_TYPE = "semantic_text"; @@ -40,15 +60,39 @@ private static SemanticTextFieldMapper toType(FieldMapper in) { return (SemanticTextFieldMapper) in; } - public static final TypeParser PARSER = new TypeParser((n, c) -> new Builder(n), notInMultiFields(CONTENT_TYPE)); + public static final TypeParser PARSER = new TypeParser( + (n, c) -> new Builder(n, c.indexVersionCreated()), + notInMultiFields(CONTENT_TYPE) + ); + + private final IndexVersion indexVersionCreated; + private final SemanticTextModelSettings modelSettings; + private final NestedObjectMapper subMappers; - private SemanticTextFieldMapper(String simpleName, MappedFieldType mappedFieldType, CopyTo copyTo) { + private SemanticTextFieldMapper( + String simpleName, + MappedFieldType mappedFieldType, + CopyTo copyTo, + IndexVersion indexVersionCreated, + SemanticTextModelSettings modelSettings, + NestedObjectMapper subMappers + ) { super(simpleName, mappedFieldType, MultiFields.empty(), copyTo); + this.indexVersionCreated = indexVersionCreated; + this.modelSettings = modelSettings; + this.subMappers = subMappers; + } + + @Override + public Iterator iterator() { + List subIterators = new ArrayList<>(); + subIterators.add(subMappers); + return subIterators.iterator(); } @Override public FieldMapper.Builder getMergeBuilder() { - return new Builder(simpleName()).init(this); + return new Builder(simpleName(), indexVersionCreated).init(this); } @Override @@ -67,39 +111,100 @@ public SemanticTextFieldType fieldType() { return (SemanticTextFieldType) super.fieldType(); } + public SemanticTextModelSettings getModelSettings() { + return modelSettings; + } + + public NestedObjectMapper getSubMappers() { + return subMappers; + } + public static class Builder extends FieldMapper.Builder { + private final IndexVersion indexVersionCreated; - private final Parameter modelId = Parameter.stringParam("model_id", false, m -> toType(m).fieldType().modelId, null) - .addValidator(v -> { - if (Strings.isEmpty(v)) { - throw new IllegalArgumentException("field [model_id] must be specified"); - } - }); + private final Parameter inferenceId = Parameter.stringParam( + "inference_id", + false, + m -> toType(m).fieldType().inferenceId, + null + ).addValidator(v -> { + if (Strings.isEmpty(v)) { + throw new IllegalArgumentException("field [inference_id] must be specified"); + } + }); + private final Parameter modelSettings = new Parameter<>( + "model_settings", + true, + () -> null, + (n, c, o) -> SemanticTextModelSettings.fromMap(o), + mapper -> ((SemanticTextFieldMapper) mapper).modelSettings, + XContentBuilder::field, + (m) -> m == null ? 
"null" : Strings.toString(m) + ).acceptsNull().setMergeValidator(SemanticTextFieldMapper::canMergeModelSettings); private final Parameter> meta = Parameter.metaParam(); - public Builder(String name) { + public Builder(String name, IndexVersion indexVersionCreated) { super(name); + this.indexVersionCreated = indexVersionCreated; + } + + public Builder setInferenceId(String id) { + this.inferenceId.setValue(id); + return this; + } + + public Builder setModelSettings(SemanticTextModelSettings value) { + this.modelSettings.setValue(value); + return this; } @Override protected Parameter[] getParameters() { - return new Parameter[] { modelId, meta }; + return new Parameter[] { inferenceId, modelSettings, meta }; } @Override public SemanticTextFieldMapper build(MapperBuilderContext context) { - return new SemanticTextFieldMapper(name(), new SemanticTextFieldType(name(), modelId.getValue(), meta.getValue()), copyTo); + final String fullName = context.buildFullName(name()); + NestedObjectMapper.Builder nestedBuilder = new NestedObjectMapper.Builder(CHUNKS, indexVersionCreated); + nestedBuilder.dynamic(ObjectMapper.Dynamic.FALSE); + KeywordFieldMapper.Builder textMapperBuilder = new KeywordFieldMapper.Builder(INFERENCE_CHUNKS_TEXT, indexVersionCreated) + .indexed(false) + .docValues(false); + if (modelSettings.get() != null) { + nestedBuilder.add(createInferenceMapperBuilder(INFERENCE_CHUNKS_RESULTS, modelSettings.get(), indexVersionCreated)); + } + nestedBuilder.add(textMapperBuilder); + var childContext = context.createChildContext(name(), ObjectMapper.Dynamic.FALSE); + var subMappers = nestedBuilder.build(childContext); + return new SemanticTextFieldMapper( + name(), + new SemanticTextFieldType(fullName, inferenceId.getValue(), modelSettings.getValue(), subMappers, meta.getValue()), + copyTo, + indexVersionCreated, + modelSettings.getValue(), + subMappers + ); } } public static class SemanticTextFieldType extends SimpleMappedFieldType implements InferenceModelFieldType { + private final String inferenceId; + private final SemanticTextModelSettings modelSettings; + private final NestedObjectMapper subMappers; - private final String modelId; - - public SemanticTextFieldType(String name, String modelId, Map meta) { + public SemanticTextFieldType( + String name, + String modelId, + SemanticTextModelSettings modelSettings, + NestedObjectMapper subMappers, + Map meta + ) { super(name, false, false, false, TextSearchInfo.NONE, meta); - this.modelId = modelId; + this.inferenceId = modelId; + this.modelSettings = modelSettings; + this.subMappers = subMappers; } @Override @@ -109,7 +214,15 @@ public String typeName() { @Override public String getInferenceId() { - return modelId; + return inferenceId; + } + + public SemanticTextModelSettings getModelSettings() { + return modelSettings; + } + + public NestedObjectMapper getSubMappers() { + return subMappers; } @Override @@ -127,4 +240,59 @@ public IndexFieldData.Builder fielddataBuilder(FieldDataContext fieldDataContext throw new IllegalArgumentException("[semantic_text] fields do not support sorting, scripting or aggregating"); } } + + @Override + public SourceLoader.SyntheticFieldLoader syntheticFieldLoader() { + return super.syntheticFieldLoader(); + } + + private static Mapper.Builder createInferenceMapperBuilder( + String fieldName, + SemanticTextModelSettings modelSettings, + IndexVersion indexVersionCreated + ) { + return switch (modelSettings.taskType()) { + case SPARSE_EMBEDDING -> new SparseVectorFieldMapper.Builder(INFERENCE_CHUNKS_RESULTS); + 
case TEXT_EMBEDDING -> { + DenseVectorFieldMapper.Builder denseVectorMapperBuilder = new DenseVectorFieldMapper.Builder( + INFERENCE_CHUNKS_RESULTS, + indexVersionCreated + ); + SimilarityMeasure similarity = modelSettings.similarity(); + if (similarity != null) { + switch (similarity) { + case COSINE -> denseVectorMapperBuilder.similarity(DenseVectorFieldMapper.VectorSimilarity.COSINE); + case DOT_PRODUCT -> denseVectorMapperBuilder.similarity(DenseVectorFieldMapper.VectorSimilarity.DOT_PRODUCT); + default -> throw new IllegalArgumentException( + "Unknown similarity measure for field [" + fieldName + "] in model settings: " + similarity + ); + } + } + denseVectorMapperBuilder.dimensions(modelSettings.dimensions()); + yield denseVectorMapperBuilder; + } + default -> throw new IllegalArgumentException( + "Invalid [task_type] for [" + fieldName + "] in model settings: " + modelSettings.taskType().name() + ); + }; + } + + static boolean canMergeModelSettings( + SemanticTextModelSettings previous, + SemanticTextModelSettings current, + FieldMapper.Conflicts conflicts + ) { + if (Objects.equals(previous, current)) { + return true; + } + if (previous == null) { + return true; + } + if (current == null) { + conflicts.addConflict("model_settings", ""); + return false; + } + conflicts.addConflict("model_settings", ""); + return false; + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextModelSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextModelSettings.java index 1b6bb22c0d6b5..b1d0511008db8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextModelSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextModelSettings.java @@ -7,73 +7,100 @@ package org.elasticsearch.xpack.inference.mapper; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.common.xcontent.support.XContentMapValues; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.DeprecationHandler; +import org.elasticsearch.xcontent.NamedXContentRegistry; import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xcontent.support.MapXContentParser; import java.io.IOException; import java.util.HashMap; import java.util.Map; import java.util.Objects; +import static org.elasticsearch.inference.TaskType.SPARSE_EMBEDDING; +import static org.elasticsearch.inference.TaskType.TEXT_EMBEDDING; + /** * Serialization class for specifying the settings of a model from semantic_text inference to field mapper. 
*/ -public class SemanticTextModelSettings { +public class SemanticTextModelSettings implements ToXContentObject { public static final String NAME = "model_settings"; public static final ParseField TASK_TYPE_FIELD = new ParseField("task_type"); - public static final ParseField INFERENCE_ID_FIELD = new ParseField("inference_id"); public static final ParseField DIMENSIONS_FIELD = new ParseField("dimensions"); public static final ParseField SIMILARITY_FIELD = new ParseField("similarity"); private final TaskType taskType; - private final String inferenceId; private final Integer dimensions; private final SimilarityMeasure similarity; - public SemanticTextModelSettings(TaskType taskType, String inferenceId, Integer dimensions, SimilarityMeasure similarity) { + public SemanticTextModelSettings(Model model) { + this(model.getTaskType(), model.getServiceSettings().dimensions(), model.getServiceSettings().similarity()); + } + + public SemanticTextModelSettings(TaskType taskType, Integer dimensions, SimilarityMeasure similarity) { Objects.requireNonNull(taskType, "task type must not be null"); - Objects.requireNonNull(inferenceId, "inferenceId must not be null"); this.taskType = taskType; - this.inferenceId = inferenceId; this.dimensions = dimensions; this.similarity = similarity; - } - - public SemanticTextModelSettings(Model model) { - this( - model.getTaskType(), - model.getInferenceEntityId(), - model.getServiceSettings().dimensions(), - model.getServiceSettings().similarity() - ); + validate(); } public static SemanticTextModelSettings parse(XContentParser parser) throws IOException { return PARSER.apply(parser, null); } - private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(NAME, args -> { - TaskType taskType = TaskType.fromString((String) args[0]); - String inferenceId = (String) args[1]; - Integer dimensions = (Integer) args[2]; - SimilarityMeasure similarity = args[3] == null ? null : SimilarityMeasure.fromString((String) args[3]); - return new SemanticTextModelSettings(taskType, inferenceId, dimensions, similarity); - }); + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + NAME, + true, + args -> { + TaskType taskType = TaskType.fromString((String) args[0]); + Integer dimensions = (Integer) args[1]; + SimilarityMeasure similarity = args[2] == null ? 
null : SimilarityMeasure.fromString((String) args[2]); + return new SemanticTextModelSettings(taskType, dimensions, similarity); + } + ); static { PARSER.declareString(ConstructingObjectParser.constructorArg(), TASK_TYPE_FIELD); - PARSER.declareString(ConstructingObjectParser.constructorArg(), INFERENCE_ID_FIELD); PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), DIMENSIONS_FIELD); PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), SIMILARITY_FIELD); } + public static SemanticTextModelSettings fromMap(Object node) { + if (node == null) { + return null; + } + try { + Map map = XContentMapValues.nodeMapValue(node, NAME); + if (map.containsKey(TASK_TYPE_FIELD.getPreferredName()) == false) { + throw new IllegalArgumentException( + "Failed to parse [" + NAME + "], required [" + TASK_TYPE_FIELD.getPreferredName() + "] is missing" + ); + } + XContentParser parser = new MapXContentParser( + NamedXContentRegistry.EMPTY, + DeprecationHandler.IGNORE_DEPRECATIONS, + map, + XContentType.JSON + ); + return SemanticTextModelSettings.parse(parser); + } catch (Exception exc) { + throw new ElasticsearchException(exc); + } + } + public Map asMap() { Map attrsMap = new HashMap<>(); attrsMap.put(TASK_TYPE_FIELD.getPreferredName(), taskType.toString()); - attrsMap.put(INFERENCE_ID_FIELD.getPreferredName(), inferenceId); if (dimensions != null) { attrsMap.put(DIMENSIONS_FIELD.getPreferredName(), dimensions); } @@ -87,10 +114,6 @@ public TaskType taskType() { return taskType; } - public String inferenceId() { - return inferenceId; - } - public Integer dimensions() { return dimensions; } @@ -98,4 +121,61 @@ public Integer dimensions() { public SimilarityMeasure similarity() { return similarity; } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(TASK_TYPE_FIELD.getPreferredName(), taskType.toString()); + if (dimensions != null) { + builder.field(DIMENSIONS_FIELD.getPreferredName(), dimensions); + } + if (similarity != null) { + builder.field(SIMILARITY_FIELD.getPreferredName(), similarity); + } + return builder.endObject(); + } + + public void validate() { + switch (taskType) { + case TEXT_EMBEDDING: + if (dimensions == null) { + throw new IllegalArgumentException( + "required [" + DIMENSIONS_FIELD + "] field is missing for task_type [" + taskType.name() + "]" + ); + } + if (similarity == null) { + throw new IllegalArgumentException( + "required [" + SIMILARITY_FIELD + "] field is missing for task_type [" + taskType.name() + "]" + ); + } + break; + case SPARSE_EMBEDDING: + break; + + default: + throw new IllegalArgumentException( + "Wrong [" + + TASK_TYPE_FIELD.getPreferredName() + + "], expected " + + TEXT_EMBEDDING + + " or " + + SPARSE_EMBEDDING + + ", got " + + taskType.name() + ); + } + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + SemanticTextModelSettings that = (SemanticTextModelSettings) o; + return taskType == that.taskType && Objects.equals(dimensions, that.dimensions) && similarity == that.similarity; + } + + @Override + public int hashCode() { + return Objects.hash(taskType, dimensions, similarity); + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java index 
a7d3fcce26116..bf3cc6334433a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java @@ -31,7 +31,7 @@ protected Collection> getPlugins() { public void testCreateIndexWithSemanticTextField() { final IndexService indexService = createIndex( "test", - client().admin().indices().prepareCreate("test").setMapping("field", "type=semantic_text,model_id=test_model") + client().admin().indices().prepareCreate("test").setMapping("field", "type=semantic_text,inference_id=test_model") ); assertEquals( indexService.getMetadata().getFieldInferenceMetadata().getFieldInferenceOptions().get("field").inferenceId(), @@ -46,7 +46,7 @@ public void testAddSemanticTextField() throws Exception { final ClusterService clusterService = getInstanceFromNode(ClusterService.class); final PutMappingClusterStateUpdateRequest request = new PutMappingClusterStateUpdateRequest(""" - { "properties": { "field": { "type": "semantic_text", "model_id": "test_model" }}}"""); + { "properties": { "field": { "type": "semantic_text", "inference_id": "test_model" }}}"""); request.indices(new Index[] { indexService.index() }); final var resultingState = ClusterStateTaskExecutorUtils.executeAndAssertSuccessful( clusterService.state(), diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java index 4a1825303b5a7..8b18cf74236a0 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java @@ -32,7 +32,7 @@ import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.json.JsonXContent; import org.elasticsearch.xpack.core.inference.results.ChunkedSparseEmbeddingResults; -import org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapper; +import org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper; import org.elasticsearch.xpack.inference.model.TestModel; import org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.junit.After; @@ -51,8 +51,8 @@ import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.awaitLatch; -import static org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapperTests.randomSparseEmbeddings; -import static org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapperTests.randomTextEmbeddings; +import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapperTests.randomSparseEmbeddings; +import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapperTests.randomTextEmbeddings; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.instanceOf; import static org.mockito.Mockito.any; @@ -285,7 +285,7 @@ private static BulkItemRequest[] randomBulkItemRequest( final ChunkedInferenceServiceResults results; switch (taskType) { case TEXT_EMBEDDING: - results = randomTextEmbeddings(chunks); + results = randomTextEmbeddings(model, chunks); break; case SPARSE_EMBEDDING: @@ -296,10 
+296,10 @@ private static BulkItemRequest[] randomBulkItemRequest( throw new AssertionError("Unknown task type " + taskType.name()); } model.putResult(text, results); - InferenceResultFieldMapper.applyFieldInference(inferenceResultsMap, field, model, results); + InferenceMetadataFieldMapper.applyFieldInference(inferenceResultsMap, field, model, results); } Map expectedDocMap = new LinkedHashMap<>(docMap); - expectedDocMap.put(InferenceResultFieldMapper.NAME, inferenceResultsMap); + expectedDocMap.put(InferenceMetadataFieldMapper.NAME, inferenceResultsMap); return new BulkItemRequest[] { new BulkItemRequest(id, new IndexRequest("index").source(docMap)), new BulkItemRequest(id, new IndexRequest("index").source(expectedDocMap)) }; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapperTests.java similarity index 57% rename from x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapperTests.java rename to x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapperTests.java index b5d75b528c6ab..37e4e5e774bec 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapperTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapperTests.java @@ -7,6 +7,8 @@ package org.elasticsearch.xpack.inference.mapper; +import org.apache.lucene.document.FeatureField; +import org.apache.lucene.index.IndexableField; import org.apache.lucene.index.Term; import org.apache.lucene.search.BooleanClause; import org.apache.lucene.search.BooleanQuery; @@ -51,26 +53,28 @@ import java.util.Collection; import java.util.HashMap; import java.util.HashSet; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Set; import java.util.function.Consumer; -import static org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapper.INFERENCE_CHUNKS_RESULTS; -import static org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapper.INFERENCE_CHUNKS_TEXT; -import static org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapper.RESULTS; +import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper.CHUNKS; +import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper.INFERENCE_CHUNKS_RESULTS; +import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper.INFERENCE_CHUNKS_TEXT; import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; -public class InferenceResultFieldMapperTests extends MetadataMapperTestCase { - private record SemanticTextInferenceResults(String fieldName, ChunkedInferenceServiceResults results, List text) {} +public class InferenceMetadataFieldMapperTests extends MetadataMapperTestCase { + private record SemanticTextInferenceResults(String fieldName, Model model, ChunkedInferenceServiceResults results, List text) {} - private record VisitedChildDocInfo(String path, int numChunks) {} + private record VisitedChildDocInfo(String path) {} private record SparseVectorSubfieldOptions(boolean include, boolean includeEmbedding, boolean includeIsTruncated) {} @Override protected String fieldName() { - return InferenceResultFieldMapper.NAME; + return 
InferenceMetadataFieldMapper.NAME; } @Override @@ -94,109 +98,129 @@ protected Collection getPlugins() { } public void testSuccessfulParse() throws IOException { - final String fieldName1 = randomAlphaOfLengthBetween(5, 15); - final String fieldName2 = randomAlphaOfLengthBetween(5, 15); - - DocumentMapper documentMapper = createDocumentMapper(mapping(b -> { - addSemanticTextMapping(b, fieldName1, randomAlphaOfLength(8)); - addSemanticTextMapping(b, fieldName2, randomAlphaOfLength(8)); - })); - ParsedDocument doc = documentMapper.parse( - source( - b -> addSemanticTextInferenceResults( - b, - List.of( - randomSemanticTextInferenceResults(fieldName1, List.of("a b", "c")), - randomSemanticTextInferenceResults(fieldName2, List.of("d e f")) + for (int depth = 1; depth < 4; depth++) { + final String fieldName1 = randomFieldName(depth); + final String fieldName2 = randomFieldName(depth + 1); + + Model model1 = randomModel(TaskType.SPARSE_EMBEDDING); + Model model2 = randomModel(TaskType.SPARSE_EMBEDDING); + XContentBuilder mapping = mapping(b -> { + addSemanticTextMapping(b, fieldName1, model1.getInferenceEntityId()); + addSemanticTextMapping(b, fieldName2, model2.getInferenceEntityId()); + }); + + MapperService mapperService = createMapperService(mapping); + SemanticTextFieldMapperTests.assertSemanticTextField(mapperService, fieldName1, false); + SemanticTextFieldMapperTests.assertSemanticTextField(mapperService, fieldName2, false); + DocumentMapper documentMapper = mapperService.documentMapper(); + ParsedDocument doc = documentMapper.parse( + source( + b -> addSemanticTextInferenceResults( + b, + List.of( + randomSemanticTextInferenceResults(fieldName1, model1, List.of("a b", "c")), + randomSemanticTextInferenceResults(fieldName2, model2, List.of("d e f")) + ) ) ) - ) - ); - - Set visitedChildDocs = new HashSet<>(); - Set expectedVisitedChildDocs = Set.of( - new VisitedChildDocInfo(fieldName1, 2), - new VisitedChildDocInfo(fieldName1, 1), - new VisitedChildDocInfo(fieldName2, 3) - ); - - List luceneDocs = doc.docs(); - assertEquals(4, luceneDocs.size()); - assertValidChildDoc(luceneDocs.get(0), doc.rootDoc(), visitedChildDocs); - assertValidChildDoc(luceneDocs.get(1), doc.rootDoc(), visitedChildDocs); - assertValidChildDoc(luceneDocs.get(2), doc.rootDoc(), visitedChildDocs); - assertEquals(doc.rootDoc(), luceneDocs.get(3)); - assertNull(luceneDocs.get(3).getParent()); - assertEquals(expectedVisitedChildDocs, visitedChildDocs); - - MapperService nestedMapperService = createMapperService(mapping(b -> { - addInferenceResultsNestedMapping(b, fieldName1); - addInferenceResultsNestedMapping(b, fieldName2); - })); - withLuceneIndex(nestedMapperService, iw -> iw.addDocuments(doc.docs()), reader -> { - NestedDocuments nested = new NestedDocuments( - nestedMapperService.mappingLookup(), - QueryBitSetProducer::new, - IndexVersion.current() - ); - LeafNestedDocuments leaf = nested.getLeafNestedDocuments(reader.leaves().get(0)); - - Set visitedNestedIdentities = new HashSet<>(); - Set expectedVisitedNestedIdentities = Set.of( - new SearchHit.NestedIdentity(fieldName1, 0, null), - new SearchHit.NestedIdentity(fieldName1, 1, null), - new SearchHit.NestedIdentity(fieldName2, 0, null) ); - assertChildLeafNestedDocument(leaf, 0, 3, visitedNestedIdentities); - assertChildLeafNestedDocument(leaf, 1, 3, visitedNestedIdentities); - assertChildLeafNestedDocument(leaf, 2, 3, visitedNestedIdentities); - assertEquals(expectedVisitedNestedIdentities, visitedNestedIdentities); - - assertNull(leaf.advance(3)); - 
assertEquals(3, leaf.doc()); - assertEquals(3, leaf.rootDoc()); - assertNull(leaf.nestedIdentity()); - - IndexSearcher searcher = newSearcher(reader); - { - TopDocs topDocs = searcher.search( - generateNestedTermSparseVectorQuery(nestedMapperService.mappingLookup().nestedLookup(), fieldName1, List.of("a")), - 10 - ); - assertEquals(1, topDocs.totalHits.value); - assertEquals(3, topDocs.scoreDocs[0].doc); - } - { - TopDocs topDocs = searcher.search( - generateNestedTermSparseVectorQuery(nestedMapperService.mappingLookup().nestedLookup(), fieldName1, List.of("a", "b")), - 10 - ); - assertEquals(1, topDocs.totalHits.value); - assertEquals(3, topDocs.scoreDocs[0].doc); + List luceneDocs = doc.docs(); + assertEquals(4, luceneDocs.size()); + for (int i = 0; i < 3; i++) { + assertEquals(doc.rootDoc(), luceneDocs.get(i).getParent()); } - { - TopDocs topDocs = searcher.search( - generateNestedTermSparseVectorQuery(nestedMapperService.mappingLookup().nestedLookup(), fieldName2, List.of("d")), - 10 + // nested docs are in reversed order + assertSparseFeatures(luceneDocs.get(0), fieldName1 + ".chunks.inference", 2); + assertSparseFeatures(luceneDocs.get(1), fieldName1 + ".chunks.inference", 1); + assertSparseFeatures(luceneDocs.get(2), fieldName2 + ".chunks.inference", 3); + assertEquals(doc.rootDoc(), luceneDocs.get(3)); + assertNull(luceneDocs.get(3).getParent()); + + withLuceneIndex(mapperService, iw -> iw.addDocuments(doc.docs()), reader -> { + NestedDocuments nested = new NestedDocuments( + mapperService.mappingLookup(), + QueryBitSetProducer::new, + IndexVersion.current() ); - assertEquals(1, topDocs.totalHits.value); - assertEquals(3, topDocs.scoreDocs[0].doc); - } - { - TopDocs topDocs = searcher.search( - generateNestedTermSparseVectorQuery(nestedMapperService.mappingLookup().nestedLookup(), fieldName2, List.of("z")), - 10 + LeafNestedDocuments leaf = nested.getLeafNestedDocuments(reader.leaves().get(0)); + + Set visitedNestedIdentities = new HashSet<>(); + Set expectedVisitedNestedIdentities = Set.of( + new SearchHit.NestedIdentity(fieldName1 + "." + CHUNKS, 0, null), + new SearchHit.NestedIdentity(fieldName1 + "." + CHUNKS, 1, null), + new SearchHit.NestedIdentity(fieldName2 + "." + CHUNKS, 0, null) ); - assertEquals(0, topDocs.totalHits.value); - } - }); + + assertChildLeafNestedDocument(leaf, 0, 3, visitedNestedIdentities); + assertChildLeafNestedDocument(leaf, 1, 3, visitedNestedIdentities); + assertChildLeafNestedDocument(leaf, 2, 3, visitedNestedIdentities); + assertEquals(expectedVisitedNestedIdentities, visitedNestedIdentities); + + assertNull(leaf.advance(3)); + assertEquals(3, leaf.doc()); + assertEquals(3, leaf.rootDoc()); + assertNull(leaf.nestedIdentity()); + + IndexSearcher searcher = newSearcher(reader); + { + TopDocs topDocs = searcher.search( + generateNestedTermSparseVectorQuery( + mapperService.mappingLookup().nestedLookup(), + fieldName1 + "." + CHUNKS, + List.of("a") + ), + 10 + ); + assertEquals(1, topDocs.totalHits.value); + assertEquals(3, topDocs.scoreDocs[0].doc); + } + { + TopDocs topDocs = searcher.search( + generateNestedTermSparseVectorQuery( + mapperService.mappingLookup().nestedLookup(), + fieldName1 + "." + CHUNKS, + List.of("a", "b") + ), + 10 + ); + assertEquals(1, topDocs.totalHits.value); + assertEquals(3, topDocs.scoreDocs[0].doc); + } + { + TopDocs topDocs = searcher.search( + generateNestedTermSparseVectorQuery( + mapperService.mappingLookup().nestedLookup(), + fieldName2 + "." 
+ CHUNKS, + List.of("d") + ), + 10 + ); + assertEquals(1, topDocs.totalHits.value); + assertEquals(3, topDocs.scoreDocs[0].doc); + } + { + TopDocs topDocs = searcher.search( + generateNestedTermSparseVectorQuery( + mapperService.mappingLookup().nestedLookup(), + fieldName2 + "." + CHUNKS, + List.of("z") + ), + 10 + ); + assertEquals(0, topDocs.totalHits.value); + } + }); + } } public void testMissingSubfields() throws IOException { final String fieldName = randomAlphaOfLengthBetween(5, 15); + final Model model = randomModel(randomBoolean() ? TaskType.SPARSE_EMBEDDING : TaskType.TEXT_EMBEDDING); - DocumentMapper documentMapper = createDocumentMapper(mapping(b -> addSemanticTextMapping(b, fieldName, randomAlphaOfLength(8)))); + DocumentMapper documentMapper = createDocumentMapper( + mapping(b -> addSemanticTextMapping(b, fieldName, model.getInferenceEntityId())) + ); { DocumentParsingException ex = expectThrows( @@ -206,7 +230,7 @@ public void testMissingSubfields() throws IOException { source( b -> addSemanticTextInferenceResults( b, - List.of(randomSemanticTextInferenceResults(fieldName, List.of("a b"))), + List.of(randomSemanticTextInferenceResults(fieldName, model, List.of("a b"))), new SparseVectorSubfieldOptions(false, true, true), true, Map.of() @@ -224,7 +248,7 @@ public void testMissingSubfields() throws IOException { source( b -> addSemanticTextInferenceResults( b, - List.of(randomSemanticTextInferenceResults(fieldName, List.of("a b"))), + List.of(randomSemanticTextInferenceResults(fieldName, model, List.of("a b"))), new SparseVectorSubfieldOptions(true, true, true), false, Map.of() @@ -242,7 +266,7 @@ public void testMissingSubfields() throws IOException { source( b -> addSemanticTextInferenceResults( b, - List.of(randomSemanticTextInferenceResults(fieldName, List.of("a b"))), + List.of(randomSemanticTextInferenceResults(fieldName, model, List.of("a b"))), new SparseVectorSubfieldOptions(false, true, true), false, Map.of() @@ -259,15 +283,18 @@ public void testMissingSubfields() throws IOException { public void testExtraSubfields() throws IOException { final String fieldName = randomAlphaOfLengthBetween(5, 15); + final Model model = randomModel(randomBoolean() ? TaskType.SPARSE_EMBEDDING : TaskType.TEXT_EMBEDDING); final List semanticTextInferenceResultsList = List.of( - randomSemanticTextInferenceResults(fieldName, List.of("a b")) + randomSemanticTextInferenceResults(fieldName, model, List.of("a b")) ); - DocumentMapper documentMapper = createDocumentMapper(mapping(b -> addSemanticTextMapping(b, fieldName, randomAlphaOfLength(8)))); + DocumentMapper documentMapper = createDocumentMapper( + mapping(b -> addSemanticTextMapping(b, fieldName, model.getInferenceEntityId())) + ); Consumer checkParsedDocument = d -> { Set visitedChildDocs = new HashSet<>(); - Set expectedVisitedChildDocs = Set.of(new VisitedChildDocInfo(fieldName, 2)); + Set expectedVisitedChildDocs = Set.of(new VisitedChildDocInfo(fieldName + "." 
+ CHUNKS)); List luceneDocs = d.docs(); assertEquals(2, luceneDocs.size()); @@ -358,28 +385,97 @@ public void testMissingSemanticTextMapping() throws IOException { DocumentParsingException.class, DocumentParsingException.class, () -> documentMapper.parse( - source(b -> addSemanticTextInferenceResults(b, List.of(randomSemanticTextInferenceResults(fieldName, List.of("a b"))))) + source( + b -> addSemanticTextInferenceResults( + b, + List.of( + randomSemanticTextInferenceResults( + fieldName, + randomModel(randomFrom(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING)), + List.of("a b") + ) + ) + ) + ) ) ); assertThat( ex.getMessage(), containsString( - Strings.format("Field [%s] is not registered as a %s field type", fieldName, SemanticTextFieldMapper.CONTENT_TYPE) + Strings.format("Field [%s] is not registered as a [%s] field type", fieldName, SemanticTextFieldMapper.CONTENT_TYPE) ) ); } + public void testMissingInferenceId() throws IOException { + DocumentMapper documentMapper = createDocumentMapper(mapping(b -> addSemanticTextMapping(b, "field", "my_id"))); + IllegalArgumentException ex = expectThrows( + DocumentParsingException.class, + IllegalArgumentException.class, + () -> documentMapper.parse( + source( + b -> b.startObject(InferenceMetadataFieldMapper.NAME) + .startObject("field") + .startObject(SemanticTextModelSettings.NAME) + .field(SemanticTextModelSettings.TASK_TYPE_FIELD.getPreferredName(), TaskType.SPARSE_EMBEDDING) + .endObject() + .endObject() + .endObject() + ) + ) + ); + assertThat(ex.getMessage(), containsString("required [inference_id] is missing")); + } + + public void testMissingModelSettings() throws IOException { + DocumentMapper documentMapper = createDocumentMapper(mapping(b -> addSemanticTextMapping(b, "field", "my_id"))); + DocumentParsingException ex = expectThrows( + DocumentParsingException.class, + DocumentParsingException.class, + () -> documentMapper.parse( + source( + b -> b.startObject(InferenceMetadataFieldMapper.NAME) + .startObject("field") + .field(InferenceMetadataFieldMapper.INFERENCE_ID, "my_id") + .endObject() + .endObject() + ) + ) + ); + assertThat(ex.getMessage(), containsString("Missing required [model_settings] for field [field] of type [semantic_text]")); + } + + public void testMissingTaskType() throws IOException { + DocumentMapper documentMapper = createDocumentMapper(mapping(b -> addSemanticTextMapping(b, "field", "my_id"))); + DocumentParsingException ex = expectThrows( + DocumentParsingException.class, + DocumentParsingException.class, + () -> documentMapper.parse( + source( + b -> b.startObject(InferenceMetadataFieldMapper.NAME) + .startObject("field") + .field(InferenceMetadataFieldMapper.INFERENCE_ID, "my_id") + .startObject(SemanticTextModelSettings.NAME) + .endObject() + .endObject() + .endObject() + ) + ) + ); + assertThat(ex.getCause().getMessage(), containsString(" Failed to parse [model_settings], required [task_type] is missing")); + } + private static void addSemanticTextMapping(XContentBuilder mappingBuilder, String fieldName, String modelId) throws IOException { mappingBuilder.startObject(fieldName); mappingBuilder.field("type", SemanticTextFieldMapper.CONTENT_TYPE); - mappingBuilder.field("model_id", modelId); + mappingBuilder.field("inference_id", modelId); mappingBuilder.endObject(); } - public static ChunkedTextEmbeddingResults randomTextEmbeddings(List inputs) { + public static ChunkedTextEmbeddingResults randomTextEmbeddings(Model model, List inputs) { List chunks = new ArrayList<>(); for (String input : inputs) { - 
double[] values = new double[5]; + double[] values = new double[model.getServiceSettings().dimensions()]; for (int j = 0; j < values.length; j++) { values[j] = randomDouble(); } @@ -400,8 +496,17 @@ public static ChunkedSparseEmbeddingResults randomSparseEmbeddings(List return new ChunkedSparseEmbeddingResults(chunks); } - private static SemanticTextInferenceResults randomSemanticTextInferenceResults(String semanticTextFieldName, List chunks) { - return new SemanticTextInferenceResults(semanticTextFieldName, randomSparseEmbeddings(chunks), chunks); + private static SemanticTextInferenceResults randomSemanticTextInferenceResults( + String semanticTextFieldName, + Model model, + List chunks + ) { + ChunkedInferenceServiceResults chunkedResults = switch (model.getTaskType()) { + case TEXT_EMBEDDING -> randomTextEmbeddings(model, chunks); + case SPARSE_EMBEDDING -> randomSparseEmbeddings(chunks); + default -> throw new AssertionError("unknown task type: " + model.getTaskType().name()); + }; + return new SemanticTextInferenceResults(semanticTextFieldName, model, chunkedResults, chunks); } private static void addSemanticTextInferenceResults( @@ -425,16 +530,16 @@ private static void addSemanticTextInferenceResults( boolean includeTextSubfield, Map extraSubfields ) throws IOException { - Map inferenceResultsMap = new HashMap<>(); + Map inferenceResultsMap = new LinkedHashMap<>(); for (SemanticTextInferenceResults semanticTextInferenceResult : semanticTextInferenceResults) { - InferenceResultFieldMapper.applyFieldInference( + InferenceMetadataFieldMapper.applyFieldInference( inferenceResultsMap, semanticTextInferenceResult.fieldName, - randomModel(), + semanticTextInferenceResult.model, semanticTextInferenceResult.results ); Map optionsMap = (Map) inferenceResultsMap.get(semanticTextInferenceResult.fieldName); - List> fieldResultList = (List>) optionsMap.get(RESULTS); + List> fieldResultList = (List>) optionsMap.get(CHUNKS); for (var entry : fieldResultList) { if (includeTextSubfield == false) { entry.remove(INFERENCE_CHUNKS_TEXT); @@ -445,15 +550,26 @@ private static void addSemanticTextInferenceResults( entry.putAll(extraSubfields); } } - sourceBuilder.field(InferenceResultFieldMapper.NAME, inferenceResultsMap); + sourceBuilder.field(InferenceMetadataFieldMapper.NAME, inferenceResultsMap); + } + + static String randomFieldName(int numLevel) { + StringBuilder builder = new StringBuilder(); + for (int i = 0; i < numLevel; i++) { + if (i > 0) { + builder.append('.'); + } + builder.append(randomAlphaOfLengthBetween(5, 15)); + } + return builder.toString(); } - private static Model randomModel() { + private static Model randomModel(TaskType taskType) { String serviceName = randomAlphaOfLengthBetween(5, 10); String inferenceId = randomAlphaOfLengthBetween(5, 10); return new TestModel( inferenceId, - TaskType.SPARSE_EMBEDDING, + taskType, serviceName, new TestModel.TestServiceSettings("my-model"), new TestModel.TestTaskSettings(randomIntBetween(1, 100)), @@ -461,29 +577,6 @@ private static Model randomModel() { ); } - private static void addInferenceResultsNestedMapping(XContentBuilder mappingBuilder, String semanticTextFieldName) throws IOException { - mappingBuilder.startObject(semanticTextFieldName); - { - mappingBuilder.field("type", "nested"); - mappingBuilder.startObject("properties"); - { - mappingBuilder.startObject(INFERENCE_CHUNKS_RESULTS); - { - mappingBuilder.field("type", "sparse_vector"); - } - mappingBuilder.endObject(); - mappingBuilder.startObject(INFERENCE_CHUNKS_TEXT); - { -
mappingBuilder.field("type", "text"); - mappingBuilder.field("index", false); - } - mappingBuilder.endObject(); - } - mappingBuilder.endObject(); - } - mappingBuilder.endObject(); - } - private static Query generateNestedTermSparseVectorQuery(NestedLookup nestedLookup, String path, List tokens) { NestedObjectMapper mapper = nestedLookup.getNestedMappers().get(path); assertNotNull(mapper); @@ -503,12 +596,10 @@ private static Query generateNestedTermSparseVectorQuery(NestedLookup nestedLook private static void assertValidChildDoc( LuceneDocument childDoc, LuceneDocument expectedParent, - Set visitedChildDocs + Collection visitedChildDocs ) { assertEquals(expectedParent, childDoc.getParent()); - visitedChildDocs.add( - new VisitedChildDocInfo(childDoc.getPath(), childDoc.getFields(childDoc.getPath() + "." + INFERENCE_CHUNKS_RESULTS).size()) - ); + visitedChildDocs.add(new VisitedChildDocInfo(childDoc.getPath())); } private static void assertChildLeafNestedDocument( @@ -524,4 +615,15 @@ private static void assertChildLeafNestedDocument( assertNotNull(leaf.nestedIdentity()); visitedNestedIdentities.add(leaf.nestedIdentity()); } + + private static void assertSparseFeatures(LuceneDocument doc, String fieldName, int expectedCount) { + int count = 0; + for (IndexableField field : doc.getFields()) { + if (field instanceof FeatureField featureField) { + assertThat(featureField.name(), equalTo(fieldName)); + ++count; + } + } + assertThat(count, equalTo(expectedCount)); + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java index a3a705c9cc902..1b5311ac9effb 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java @@ -11,11 +11,17 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.index.mapper.DocumentMapper; +import org.elasticsearch.index.mapper.KeywordFieldMapper; import org.elasticsearch.index.mapper.MappedFieldType; +import org.elasticsearch.index.mapper.Mapper; +import org.elasticsearch.index.mapper.MapperBuilderContext; import org.elasticsearch.index.mapper.MapperParsingException; import org.elasticsearch.index.mapper.MapperService; import org.elasticsearch.index.mapper.MapperTestCase; +import org.elasticsearch.index.mapper.NestedObjectMapper; import org.elasticsearch.index.mapper.ParsedDocument; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.inference.InferencePlugin; @@ -26,52 +32,12 @@ import java.util.List; import static java.util.Collections.singletonList; +import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper.createSemanticFieldContext; import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.instanceOf; public class SemanticTextFieldMapperTests extends MapperTestCase { - - public void testDefaults() throws Exception { - DocumentMapper mapper = createDocumentMapper(fieldMapping(this::minimalMapping)); - 
assertEquals(Strings.toString(fieldMapping(this::minimalMapping)), mapper.mappingSource().toString()); - - ParsedDocument doc1 = mapper.parse(source(this::writeField)); - List fields = doc1.rootDoc().getFields("field"); - - // No indexable fields - assertTrue(fields.isEmpty()); - } - - public void testModelIdNotPresent() throws IOException { - Exception e = expectThrows( - MapperParsingException.class, - () -> createMapperService(fieldMapping(b -> b.field("type", "semantic_text"))) - ); - assertThat(e.getMessage(), containsString("field [model_id] must be specified")); - } - - public void testCannotBeUsedInMultiFields() { - Exception e = expectThrows(MapperParsingException.class, () -> createMapperService(fieldMapping(b -> { - b.field("type", "text"); - b.startObject("fields"); - b.startObject("semantic"); - b.field("type", "semantic_text"); - b.endObject(); - b.endObject(); - }))); - assertThat(e.getMessage(), containsString("Field [semantic] of type [semantic_text] can't be used in multifields")); - } - - public void testUpdatesToModelIdNotSupported() throws IOException { - MapperService mapperService = createMapperService( - fieldMapping(b -> b.field("type", "semantic_text").field("model_id", "test_model")) - ); - Exception e = expectThrows( - IllegalArgumentException.class, - () -> merge(mapperService, fieldMapping(b -> b.field("type", "semantic_text").field("model_id", "another_model"))) - ); - assertThat(e.getMessage(), containsString("Cannot update parameter [model_id] from [test_model] to [another_model]")); - } - @Override protected Collection getPlugins() { return singletonList(new InferencePlugin(Settings.EMPTY)); @@ -79,7 +45,12 @@ protected Collection getPlugins() { @Override protected void minimalMapping(XContentBuilder b) throws IOException { - b.field("type", "semantic_text").field("model_id", "test_model"); + b.field("type", "semantic_text").field("inference_id", "test_model"); + } + + @Override + protected String minimalIsInvalidRoutingPathErrorMessage(Mapper mapper) { + return "cannot have nested fields when index is in [index.mode=time_series]"; } @Override @@ -115,4 +86,180 @@ protected SyntheticSourceSupport syntheticSourceSupport(boolean ignoreMalformed) protected IngestScriptSupport ingestScriptSupport() { throw new AssumptionViolatedException("not supported"); } + + public void testDefaults() throws Exception { + DocumentMapper mapper = createDocumentMapper(fieldMapping(this::minimalMapping)); + assertEquals(Strings.toString(fieldMapping(this::minimalMapping)), mapper.mappingSource().toString()); + + ParsedDocument doc1 = mapper.parse(source(this::writeField)); + List fields = doc1.rootDoc().getFields("field"); + + // No indexable fields + assertTrue(fields.isEmpty()); + } + + public void testInferenceIdNotPresent() throws IOException { + Exception e = expectThrows( + MapperParsingException.class, + () -> createMapperService(fieldMapping(b -> b.field("type", "semantic_text"))) + ); + assertThat(e.getMessage(), containsString("field [inference_id] must be specified")); + } + + public void testCannotBeUsedInMultiFields() { + Exception e = expectThrows(MapperParsingException.class, () -> createMapperService(fieldMapping(b -> { + b.field("type", "text"); + b.startObject("fields"); + b.startObject("semantic"); + b.field("type", "semantic_text"); + b.endObject(); + b.endObject(); + }))); + assertThat(e.getMessage(), containsString("Field [semantic] of type [semantic_text] can't be used in multifields")); + } + + public void testUpdatesToInferenceIdNotSupported() throws 
IOException { + String fieldName = randomAlphaOfLengthBetween(5, 15); + MapperService mapperService = createMapperService( + mapping(b -> b.startObject(fieldName).field("type", "semantic_text").field("inference_id", "test_model").endObject()) + ); + assertSemanticTextField(mapperService, fieldName, false); + Exception e = expectThrows( + IllegalArgumentException.class, + () -> merge( + mapperService, + mapping(b -> b.startObject(fieldName).field("type", "semantic_text").field("inference_id", "another_model").endObject()) + ) + ); + assertThat(e.getMessage(), containsString("Cannot update parameter [inference_id] from [test_model] to [another_model]")); + } + + public void testUpdateModelSettings() throws IOException { + for (int depth = 1; depth < 5; depth++) { + String fieldName = InferenceMetadataFieldMapperTests.randomFieldName(depth); + MapperService mapperService = createMapperService( + mapping(b -> b.startObject(fieldName).field("type", "semantic_text").field("inference_id", "test_model").endObject()) + ); + assertSemanticTextField(mapperService, fieldName, false); + { + Exception exc = expectThrows( + MapperParsingException.class, + () -> merge( + mapperService, + mapping( + b -> b.startObject(fieldName) + .field("type", "semantic_text") + .field("inference_id", "test_model") + .startObject("model_settings") + .field("inference_id", "test_model") + .endObject() + .endObject() + ) + ) + ); + assertThat(exc.getMessage(), containsString("Failed to parse [model_settings], required [task_type] is missing")); + } + { + merge( + mapperService, + mapping( + b -> b.startObject(fieldName) + .field("type", "semantic_text") + .field("inference_id", "test_model") + .startObject("model_settings") + .field("task_type", "sparse_embedding") + .endObject() + .endObject() + ) + ); + assertSemanticTextField(mapperService, fieldName, true); + } + { + Exception exc = expectThrows( + IllegalArgumentException.class, + () -> merge( + mapperService, + mapping( + b -> b.startObject(fieldName).field("type", "semantic_text").field("inference_id", "test_model").endObject() + ) + ) + ); + assertThat( + exc.getMessage(), + containsString("Cannot update parameter [model_settings] " + "from [{\"task_type\":\"sparse_embedding\"}] to [null]") + ); + } + { + Exception exc = expectThrows( + IllegalArgumentException.class, + () -> merge( + mapperService, + mapping( + b -> b.startObject(fieldName) + .field("type", "semantic_text") + .field("inference_id", "test_model") + .startObject("model_settings") + .field("task_type", "text_embedding") + .field("dimensions", 10) + .field("similarity", "cosine") + .endObject() + .endObject() + ) + ) + ); + assertThat( + exc.getMessage(), + containsString( + "Cannot update parameter [model_settings] " + + "from [{\"task_type\":\"sparse_embedding\"}] " + + "to [{\"task_type\":\"text_embedding\",\"dimensions\":10,\"similarity\":\"cosine\"}]" + ) + ); + } + } + } + + static void assertSemanticTextField(MapperService mapperService, String fieldName, boolean expectedModelSettings) { + InferenceMetadataFieldMapper.SemanticTextMapperContext res = createSemanticFieldContext( + MapperBuilderContext.root(false, false), + mapperService.mappingLookup().getMapping().getRoot(), + fieldName.split("\\.") + ); + Mapper mapper = res.mapper(); + assertNotNull(mapper); + assertThat(mapper, instanceOf(SemanticTextFieldMapper.class)); + SemanticTextFieldMapper semanticFieldMapper = (SemanticTextFieldMapper) mapper; + + var fieldType = mapperService.fieldType(fieldName); + assertNotNull(fieldType); + 
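// the mapper and its field type must expose the same sub-mappers and model settings instances +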
assertThat(fieldType, instanceOf(SemanticTextFieldMapper.SemanticTextFieldType.class)); + SemanticTextFieldMapper.SemanticTextFieldType semanticTextFieldType = (SemanticTextFieldMapper.SemanticTextFieldType) fieldType; + assertSame(semanticTextFieldType, semanticFieldMapper.fieldType()); + assertSame(semanticTextFieldType.getSubMappers(), semanticFieldMapper.getSubMappers()); + assertSame(semanticTextFieldType.getModelSettings(), semanticFieldMapper.getModelSettings()); + + NestedObjectMapper nestedObjectMapper = mapperService.mappingLookup() + .nestedLookup() + .getNestedMappers() + .get(fieldName + "." + InferenceMetadataFieldMapper.CHUNKS); + assertThat(nestedObjectMapper, equalTo(semanticFieldMapper.getSubMappers())); + Mapper textMapper = nestedObjectMapper.getMapper(InferenceMetadataFieldMapper.INFERENCE_CHUNKS_TEXT); + assertNotNull(textMapper); + assertThat(textMapper, instanceOf(KeywordFieldMapper.class)); + KeywordFieldMapper textFieldMapper = (KeywordFieldMapper) textMapper; + assertFalse(textFieldMapper.fieldType().isIndexed()); + assertFalse(textFieldMapper.fieldType().hasDocValues()); + if (expectedModelSettings) { + assertNotNull(semanticFieldMapper.getModelSettings()); + Mapper inferenceMapper = nestedObjectMapper.getMapper(InferenceMetadataFieldMapper.INFERENCE_CHUNKS_RESULTS); + assertNotNull(inferenceMapper); + switch (semanticFieldMapper.getModelSettings().taskType()) { + case SPARSE_EMBEDDING -> assertThat(inferenceMapper, instanceOf(SparseVectorFieldMapper.class)); + case TEXT_EMBEDDING -> assertThat(inferenceMapper, instanceOf(DenseVectorFieldMapper.class)); + default -> throw new AssertionError("Invalid task type"); + } + } else { + assertNull(semanticFieldMapper.getModelSettings()); + } + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/model/TestModel.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/model/TestModel.java index 75e7ca12c1d56..b64485a3d3fb2 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/model/TestModel.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/model/TestModel.java @@ -16,6 +16,7 @@ import org.elasticsearch.inference.ModelSecrets; import org.elasticsearch.inference.SecretSettings; import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskSettings; import org.elasticsearch.inference.TaskType; import org.elasticsearch.xcontent.ToXContentObject; @@ -121,6 +122,16 @@ public void writeTo(StreamOutput out) throws IOException { public ToXContentObject getFilteredXContentObject() { return this; } + + @Override + public SimilarityMeasure similarity() { + return SimilarityMeasure.COSINE; + } + + @Override + public Integer dimensions() { + return 100; + } } public record TestTaskSettings(Integer temperature) implements TaskSettings { diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml index 6008ebbcbedf8..528003e278aeb 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_inference.yml @@ -27,6 +27,7 @@ setup: "service_settings": { "model": "my_model", "dimensions": 10, +
"similarity": "cosine", "api_key": "abc64" }, "task_settings": { @@ -41,10 +42,10 @@ setup: properties: inference_field: type: semantic_text - model_id: sparse-inference-id + inference_id: sparse-inference-id another_inference_field: type: semantic_text - model_id: sparse-inference-id + inference_id: sparse-inference-id non_inference_field: type: text @@ -56,10 +57,10 @@ setup: properties: inference_field: type: semantic_text - model_id: dense-inference-id + inference_id: dense-inference-id another_inference_field: type: semantic_text - model_id: dense-inference-id + inference_id: dense-inference-id non_inference_field: type: text @@ -83,11 +84,11 @@ setup: - match: { _source.another_inference_field: "another inference test" } - match: { _source.non_inference_field: "non inference test" } - - match: { _source._inference.inference_field.results.0.text: "inference test" } - - match: { _source._inference.another_inference_field.results.0.text: "another inference test" } + - match: { _source._inference.inference_field.chunks.0.text: "inference test" } + - match: { _source._inference.another_inference_field.chunks.0.text: "another inference test" } - - exists: _source._inference.inference_field.results.0.inference - - exists: _source._inference.another_inference_field.results.0.inference + - exists: _source._inference.inference_field.chunks.0.inference + - exists: _source._inference.another_inference_field.chunks.0.inference --- "text expansion documents do not create new mappings": @@ -120,11 +121,11 @@ setup: - match: { _source.another_inference_field: "another inference test" } - match: { _source.non_inference_field: "non inference test" } - - match: { _source._inference.inference_field.results.0.text: "inference test" } - - match: { _source._inference.another_inference_field.results.0.text: "another inference test" } + - match: { _source._inference.inference_field.chunks.0.text: "inference test" } + - match: { _source._inference.another_inference_field.chunks.0.text: "another inference test" } - - exists: _source._inference.inference_field.results.0.inference - - exists: _source._inference.another_inference_field.results.0.inference + - exists: _source._inference.inference_field.chunks.0.inference + - exists: _source._inference.another_inference_field.chunks.0.inference --- @@ -154,8 +155,8 @@ setup: index: test-sparse-index id: doc_1 - - set: { _source._inference.inference_field.results.0.inference: inference_field_embedding } - - set: { _source._inference.another_inference_field.results.0.inference: another_inference_field_embedding } + - set: { _source._inference.inference_field.chunks.0.inference: inference_field_embedding } + - set: { _source._inference.another_inference_field.chunks.0.inference: another_inference_field_embedding } - do: update: @@ -174,11 +175,11 @@ setup: - match: { _source.another_inference_field: "another inference test" } - match: { _source.non_inference_field: "another non inference test" } - - match: { _source._inference.inference_field.results.0.text: "inference test" } - - match: { _source._inference.another_inference_field.results.0.text: "another inference test" } + - match: { _source._inference.inference_field.chunks.0.text: "inference test" } + - match: { _source._inference.another_inference_field.chunks.0.text: "another inference test" } - - match: { _source._inference.inference_field.results.0.inference: $inference_field_embedding } - - match: { _source._inference.another_inference_field.results.0.inference: $another_inference_field_embedding } + - 
match: { _source._inference.inference_field.chunks.0.inference: $inference_field_embedding } + - match: { _source._inference.another_inference_field.chunks.0.inference: $another_inference_field_embedding } --- "Updating semantic_text fields recalculates embeddings": @@ -214,8 +215,8 @@ setup: - match: { _source.another_inference_field: "another updated inference test" } - match: { _source.non_inference_field: "non inference test" } - - match: { _source._inference.inference_field.results.0.text: "updated inference test" } - - match: { _source._inference.another_inference_field.results.0.text: "another updated inference test" } + - match: { _source._inference.inference_field.chunks.0.text: "updated inference test" } + - match: { _source._inference.another_inference_field.chunks.0.text: "another updated inference test" } --- "Reindex works for semantic_text fields": @@ -233,8 +234,8 @@ setup: index: test-sparse-index id: doc_1 - - set: { _source._inference.inference_field.results.0.inference: inference_field_embedding } - - set: { _source._inference.another_inference_field.results.0.inference: another_inference_field_embedding } + - set: { _source._inference.inference_field.chunks.0.inference: inference_field_embedding } + - set: { _source._inference.another_inference_field.chunks.0.inference: another_inference_field_embedding } - do: indices.refresh: { } @@ -247,10 +248,10 @@ setup: properties: inference_field: type: semantic_text - model_id: sparse-inference-id + inference_id: sparse-inference-id another_inference_field: type: semantic_text - model_id: sparse-inference-id + inference_id: sparse-inference-id non_inference_field: type: text @@ -271,11 +272,11 @@ setup: - match: { _source.another_inference_field: "another inference test" } - match: { _source.non_inference_field: "non inference test" } - - match: { _source._inference.inference_field.results.0.text: "inference test" } - - match: { _source._inference.another_inference_field.results.0.text: "another inference test" } + - match: { _source._inference.inference_field.chunks.0.text: "inference test" } + - match: { _source._inference.another_inference_field.chunks.0.text: "another inference test" } - - match: { _source._inference.inference_field.results.0.inference: $inference_field_embedding } - - match: { _source._inference.another_inference_field.results.0.inference: $another_inference_field_embedding } + - match: { _source._inference.inference_field.chunks.0.inference: $inference_field_embedding } + - match: { _source._inference.another_inference_field.chunks.0.inference: $another_inference_field_embedding } --- "Fails for non-existent model": @@ -287,7 +288,7 @@ setup: properties: inference_field: type: semantic_text - model_id: non-existing-inference-id + inference_id: non-existing-inference-id non_inference_field: type: text diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml index 2c69f49218091..27f233436b925 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml @@ -27,7 +27,8 @@ setup: "service_settings": { "model": "my_model", "dimensions": 10, - "api_key": "abc64" + "api_key": "abc64", + "similarity": "cosine" }, "task_settings": { } @@ -41,10 
+42,10 @@ setup: properties: sparse_field: type: semantic_text - model_id: sparse-inference-id + inference_id: sparse-inference-id dense_field: type: semantic_text - model_id: dense-inference-id + inference_id: dense-inference-id non_inference_field: type: text @@ -55,25 +56,7 @@ setup: index: test-index id: doc_1 body: - non_inference_field: "you know, for testing" - _inference: - sparse_field: - model_settings: - inference_id: sparse-inference-id - task_type: sparse_embedding - results: - - text: "inference test" - inference: - feature_1: 0.1 - feature_2: 0.2 - feature_3: 0.3 - feature_4: 0.4 - - text: "another inference test" - inference: - feature_1: 0.1 - feature_2: 0.2 - feature_3: 0.3 - feature_4: 0.4 + sparse_field: "you know, for testing" --- "Dense vector results format": @@ -82,72 +65,4 @@ setup: index: test-index id: doc_1 body: - non_inference_field: "you know, for testing" - _inference: - dense_field: - model_settings: - inference_id: sparse-inference-id - task_type: text_embedding - dimensions: 5 - similarity: cosine - results: - - text: "inference test" - inference: [0.1, 0.2, 0.3, 0.4, 0.5] - - text: "another inference test" - inference: [-0.1, -0.2, -0.3, -0.4, -0.5] - ---- -"Model settings inference id not included": - - do: - catch: /Required \[inference_id\]/ - index: - index: test-index - id: doc_1 - body: - non_inference_field: "you know, for testing" - _inference: - sparse_field: - model_settings: - task_type: sparse_embedding - results: - - text: "inference test" - inference: - feature_1: 0.1 - ---- -"Model settings task type not included": - - do: - catch: /Required \[task_type\]/ - index: - index: test-index - id: doc_1 - body: - non_inference_field: "you know, for testing" - _inference: - sparse_field: - model_settings: - inference_id: sparse-inference-id - results: - - text: "inference test" - inference: - feature_1: 0.1 - ---- -"Model settings dense vector dimensions not included": - - do: - catch: /Model settings for field \[dense_field\] must contain dimensions/ - index: - index: test-index - id: doc_1 - body: - non_inference_field: "you know, for testing" - _inference: - dense_field: - model_settings: - inference_id: sparse-inference-id - task_type: text_embedding - results: - - text: "inference test" - inference: [0.1, 0.2, 0.3, 0.4, 0.5] - - text: "another inference test" - inference: [-0.1, -0.2, -0.3, -0.4, -0.5] + dense_field: "you know, for testing"
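
Note: taken together, the tests above pin down the _source layout that this change produces for semantic_text fields. As a rough sketch reconstructed from those assertions (the field name and inference endpoint id mirror the test fixtures and are illustrative only), an ingested document now looks like:

    {
      "sparse_field": "inference test",
      "_inference": {
        "sparse_field": {
          "inference_id": "sparse-inference-id",
          "model_settings": { "task_type": "sparse_embedding" },
          "chunks": [
            { "text": "inference test", "inference": { "feature_1": 0.1, "feature_2": 0.2 } }
          ]
        }
      }
    }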