From b2b863579d0c301f3d781f503d7b35996b449e69 Mon Sep 17 00:00:00 2001
From: Jim Ferenczi
Date: Tue, 19 Mar 2024 18:44:29 +0000
Subject: [PATCH] Refactor the semantic_text field so that it can register all the sub-fields in the mapping

---
 .../index/mapper/FieldMapper.java | 8 +-
 .../vectors/SparseVectorFieldMapper.java | 7 +-
 .../xpack/inference/InferencePlugin.java | 4 +-
 .../ShardBulkInferenceActionFilter.java | 8 +-
 .../mapper/InferenceMetadataFieldMapper.java | 385 ++++++++++++++++++
 .../mapper/InferenceResultFieldMapper.java | 372 -----------------
 .../mapper/SemanticTextFieldMapper.java | 197 ++++++++-
 .../mapper/SemanticTextModelSettings.java | 45 +-
 .../ShardBulkInferenceActionFilterTests.java | 10 +-
 ...=> InferenceMetadataFieldMapperTests.java} | 309 +++++++-------
 .../mapper/SemanticTextFieldMapperTests.java | 6 +
 .../20_semantic_text_field_mapper.yml | 4 +-
 12 files changed, 815 insertions(+), 540 deletions(-)
 create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapper.java
 delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapper.java
 rename x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/{InferenceResultFieldMapperTests.java => InferenceMetadataFieldMapperTests.java} (66%)

diff --git a/server/src/main/java/org/elasticsearch/index/mapper/FieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/FieldMapper.java
index 71fd9edd49903..f9354025cab49 100644
--- a/server/src/main/java/org/elasticsearch/index/mapper/FieldMapper.java
+++ b/server/src/main/java/org/elasticsearch/index/mapper/FieldMapper.java
@@ -1176,7 +1176,7 @@ public static final class Conflicts {
 
     private final String mapperName;
     private final List<String> conflicts = new ArrayList<>();
 
-    Conflicts(String mapperName) {
+    public Conflicts(String mapperName) {
         this.mapperName = mapperName;
     }
 
@@ -1188,7 +1188,11 @@ void addConflict(String parameter, String existing, String toMerge) {
         conflicts.add("Cannot update parameter [" + parameter + "] from [" + existing + "] to [" + toMerge + "]");
     }
 
-    void check() {
+    public boolean hasConflicts() {
+        return conflicts.isEmpty() == false;
+    }
+
+    public void check() {
         if (conflicts.isEmpty()) {
             return;
         }
diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/SparseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/SparseVectorFieldMapper.java
index 6532abed19044..58286d34dada1 100644
--- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/SparseVectorFieldMapper.java
+++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/SparseVectorFieldMapper.java
@@ -171,9 +171,12 @@ public void parse(DocumentParserContext context) throws IOException {
         }
 
         String feature = null;
+        boolean origIsWithLeafObject = context.path().isWithinLeafObject();
         try {
             // make sure that we don't expand dots in field names while parsing
-            context.path().setWithinLeafObject(true);
+            if (context.path().isWithinLeafObject() == false) {
+                context.path().setWithinLeafObject(true);
+            }
             for (Token token = context.parser().nextToken(); token != Token.END_OBJECT; token = context.parser().nextToken()) {
                 if (token == Token.FIELD_NAME) {
                     feature = context.parser().currentName();
@@ -207,7 +210,7 @@ public void parse(DocumentParserContext context) throws IOException {
                 context.addToFieldNames(fieldType().name());
             }
         } finally {
-            context.path().setWithinLeafObject(false);
+            context.path().setWithinLeafObject(origIsWithLeafObject);
         }
     }
 
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java
index 994207766f2a6..24c1950be1915 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java
@@ -55,7 +55,7 @@
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderFactory;
 import org.elasticsearch.xpack.inference.external.http.sender.RequestExecutorServiceSettings;
 import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
-import org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapper;
+import org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper;
 import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;
 import org.elasticsearch.xpack.inference.registry.ModelRegistry;
 import org.elasticsearch.xpack.inference.rest.RestDeleteInferenceModelAction;
@@ -285,7 +285,7 @@ public Map<String, Mapper.TypeParser> getMappers() {
 
     @Override
     public Map<String, MetadataFieldMapper.TypeParser> getMetadataMappers() {
-        return Map.of(InferenceResultFieldMapper.NAME, InferenceResultFieldMapper.PARSER);
+        return Map.of(InferenceMetadataFieldMapper.NAME, InferenceMetadataFieldMapper.PARSER);
     }
 
     @Override
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java
index e679d3c970abf..47fae274095e4 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java
@@ -36,7 +36,7 @@
 import org.elasticsearch.inference.Model;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.tasks.Task;
-import org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapper;
+import org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper;
 import org.elasticsearch.xpack.inference.registry.ModelRegistry;
 
 import java.util.ArrayList;
@@ -49,7 +49,7 @@
 /**
  * An {@link ActionFilter} that performs inference on {@link BulkShardRequest} asynchronously and stores the results in
- * the individual {@link BulkItemRequest}. The results are then consumed by the {@link InferenceResultFieldMapper}
+ * the individual {@link BulkItemRequest}. The results are then consumed by the {@link InferenceMetadataFieldMapper}
  * in the subsequent {@link TransportShardBulkAction} downstream.
 */
 public class ShardBulkInferenceActionFilter implements ActionFilter {
@@ -261,10 +261,10 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceResponses
             Map<String, Object> newDocMap = indexRequest.sourceAsMap();
             Map<String, Object> inferenceMap = new LinkedHashMap<>();
             // ignore the existing inference map if any
-            newDocMap.put(InferenceResultFieldMapper.NAME, inferenceMap);
+            newDocMap.put(InferenceMetadataFieldMapper.NAME, inferenceMap);
             for (FieldInferenceResponse fieldResponse : response.responses()) {
                 try {
-                    InferenceResultFieldMapper.applyFieldInference(
+                    InferenceMetadataFieldMapper.applyFieldInference(
                         inferenceMap,
                         fieldResponse.field(),
                         fieldResponse.model(),
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapper.java
new file mode 100644
index 0000000000000..831509288696f
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapper.java
@@ -0,0 +1,385 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.mapper;
+
+import org.apache.lucene.search.Query;
+import org.elasticsearch.ElasticsearchException;
+import org.elasticsearch.ElasticsearchStatusException;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.xcontent.XContentHelper;
+import org.elasticsearch.index.mapper.DocumentParserContext;
+import org.elasticsearch.index.mapper.DocumentParsingException;
+import org.elasticsearch.index.mapper.FieldMapper;
+import org.elasticsearch.index.mapper.MappedFieldType;
+import org.elasticsearch.index.mapper.Mapper;
+import org.elasticsearch.index.mapper.MapperBuilderContext;
+import org.elasticsearch.index.mapper.MetadataFieldMapper;
+import org.elasticsearch.index.mapper.NestedObjectMapper;
+import org.elasticsearch.index.mapper.ObjectMapper;
+import org.elasticsearch.index.mapper.SourceLoader;
+import org.elasticsearch.index.mapper.SourceValueFetcher;
+import org.elasticsearch.index.mapper.TextSearchInfo;
+import org.elasticsearch.index.mapper.ValueFetcher;
+import org.elasticsearch.index.query.SearchExecutionContext;
+import org.elasticsearch.inference.ChunkedInferenceServiceResults;
+import org.elasticsearch.inference.Model;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.logging.LogManager;
+import org.elasticsearch.logging.Logger;
+import org.elasticsearch.rest.RestStatus;
+import org.elasticsearch.xcontent.DeprecationHandler;
+import org.elasticsearch.xcontent.NamedXContentRegistry;
+import org.elasticsearch.xcontent.XContentLocation;
+import org.elasticsearch.xcontent.XContentParser;
+import org.elasticsearch.xcontent.XContentParserConfiguration;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xcontent.support.MapXContentParser;
+import org.elasticsearch.xpack.core.inference.results.ChunkedSparseEmbeddingResults;
+import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingResults;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+/**
+ * A mapper for the {@code _inference} field.
+ *
+ *
+ * This mapper works in tandem with {@link SemanticTextFieldMapper semantic_text} fields to index inference results.
+ * The inference results for {@code semantic_text} fields are written to {@code _source} by an upstream process like so:
+ *
+ *
+ *
+ * {
+ *     "_source": {
+ *         "my_semantic_text_field": "these are not the droids you're looking for",
+ *         "_inference": {
+ *             "my_semantic_text_field": {
+ *                  "model_settings": {
+ *                      "inference_id": "my_inference_id",
+ *                      "task_type": "SPARSE_EMBEDDING"
+ *                  },
+ *                  "results": [
+ *                      {
+ *                          "inference": {
+ *                              "lucas": 0.05212344,
+ *                              "ty": 0.041213956,
+ *                              "dragon": 0.50991,
+ *                              "type": 0.23241979,
+ *                              "dr": 1.9312073,
+ *                              "##o": 0.2797593
+ *                          },
+ *                          "text": "these are not the droids you're looking for"
+ *                      }
+ *                  ]
+ *              }
+ *          }
+ *      }
+ * }
+ * 
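+ * An upstream component (see {@code ShardBulkInferenceActionFilter}) typically builds the {@code _inference} map
+ * through {@link #applyFieldInference}. A minimal sketch, where {@code model}, {@code results} and {@code sourceMap}
+ * are hypothetical variables holding the inference model, its chunked results and the parsed {@code _source}:
+ *
+ * Map<String, Object> inferenceMap = new LinkedHashMap<>();
+ * InferenceMetadataFieldMapper.applyFieldInference(inferenceMap, "my_semantic_text_field", model, results);
+ * sourceMap.put(InferenceMetadataFieldMapper.NAME, inferenceMap); // written back to the document's _source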
+ *
+ * This mapper parses the contents of the {@code _inference} field and indexes it as if the mapping were configured like so:
+ *
+ *
+ *
+ * {
+ *     "mappings": {
+ *         "properties": {
+ *             "my_semantic_text_field": {
+ *                 "properties": {
+ *                     "results": {
+ *                         "type": "nested",
+ *                         "properties": {
+ *                             "inference": {
+ *                                 "type": "sparse_vector"
+ *                             },
+ *                             "text": {
+ *                                 "type": "text",
+ *                                 "index": false
+ *                             }
+ *                         }
+ *                     }
+ *                 }
+ *             }
+ *         }
+ *     }
+ * }
+ * 
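+ * At the Lucene level each chunk becomes a nested (child) document, and the sparse vector features are indexed with
+ * {@code FeatureField}, where each feature name is indexed as a term on the field. A single chunk feature can
+ * therefore be matched with a plain term query on the chunk field; an illustrative sketch only, which omits the
+ * child-to-parent join that a real nested query adds:
+ *
+ * Query chunkQuery = new TermQuery(new Term("my_semantic_text_field.results.inference", "dragon"));
+ *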
+ */
+public class InferenceMetadataFieldMapper extends MetadataFieldMapper {
+    public static final String NAME = "_inference";
+    public static final String CONTENT_TYPE = "_inference";
+
+    public static final String RESULTS = "results";
+    public static final String INFERENCE_CHUNKS_RESULTS = "inference";
+    public static final String INFERENCE_CHUNKS_TEXT = "text";
+
+    public static final TypeParser PARSER = new FixedTypeParser(c -> new InferenceMetadataFieldMapper());
+
+    private static final Logger logger = LogManager.getLogger(InferenceMetadataFieldMapper.class);
+
+    private static final Set<String> REQUIRED_SUBFIELDS = Set.of(INFERENCE_CHUNKS_TEXT, INFERENCE_CHUNKS_RESULTS);
+
+    static class SemanticTextInferenceFieldType extends MappedFieldType {
+        private static final MappedFieldType INSTANCE = new SemanticTextInferenceFieldType();
+
+        SemanticTextInferenceFieldType() {
+            super(NAME, true, false, false, TextSearchInfo.NONE, Collections.emptyMap());
+        }
+
+        @Override
+        public String typeName() {
+            return CONTENT_TYPE;
+        }
+
+        @Override
+        public ValueFetcher valueFetcher(SearchExecutionContext context, String format) {
+            return SourceValueFetcher.identity(name(), context, format);
+        }
+
+        @Override
+        public Query termQuery(Object value, SearchExecutionContext context) {
+            return null;
+        }
+    }
+
+    public InferenceMetadataFieldMapper() {
+        super(SemanticTextInferenceFieldType.INSTANCE);
+    }
+
+    @Override
+    protected void parseCreateField(DocumentParserContext context) throws IOException {
+        XContentParser parser = context.parser();
+        failIfTokenIsNot(parser.getTokenLocation(), parser, XContentParser.Token.START_OBJECT);
+        MapperBuilderContext mapperBuilderContext = MapperBuilderContext.root(false, false);
+        boolean origWithLeafObject = context.path().isWithinLeafObject();
+        try {
+            // make sure that we don't expand dots in field names while parsing
+            context.path().setWithinLeafObject(true);
+            for (XContentParser.Token token = parser.nextToken(); token != XContentParser.Token.END_OBJECT; token = parser.nextToken()) {
+                failIfTokenIsNot(parser.getTokenLocation(), parser, XContentParser.Token.FIELD_NAME);
+                parseSingleField(context, mapperBuilderContext);
+            }
+        } finally {
+            context.path().setWithinLeafObject(origWithLeafObject);
+        }
+    }
+
+    private SemanticTextFieldMapper updateSemanticTextFieldMapper(
+        DocumentParserContext docContext,
+        MapperBuilderContext mapperBuilderContext,
+        SemanticTextFieldMapper original,
+        SemanticTextModelSettings modelSettings,
+        XContentLocation xContentLocation
+    ) {
+        if (modelSettings.inferenceId().equals(original.fieldType().getInferenceModel()) == false) {
+            throw new DocumentParsingException(
+                xContentLocation,
+                "The model ID for field ["
+                    + original.fieldType().name()
+                    + "] is already set to ["
+                    + original.fieldType().getInferenceModel()
+                    + "], got ["
+                    + modelSettings.inferenceId()
+                    + "]"
+            );
+        }
+        if (modelSettings.taskType() == TaskType.TEXT_EMBEDDING && modelSettings.dimensions() == null) {
+            throw new DocumentParsingException(
+                xContentLocation,
+                "Model settings for field [" + original.fieldType().name() + "] must contain dimensions"
+            );
+        }
+
+        if (original.getModelSettings() == null) {
+            SemanticTextFieldMapper newMapper = new SemanticTextFieldMapper.Builder(
+                original.name(),
+                docContext.indexSettings().getIndexVersionCreated(),
+                docContext.indexAnalyzers()
+            ).setModelId(modelSettings.inferenceId()).setModelSettings(modelSettings).build(mapperBuilderContext);
+            docContext.addDynamicMapper(newMapper);
+            return newMapper;
+        } else {
+            var conflicts = new FieldMapper.Conflicts(original.name());
+            SemanticTextModelSettings.checkCompatibility(original.getModelSettings(), modelSettings, conflicts);
+            try {
+                conflicts.check();
+            } catch (Exception exc) {
+                throw new DocumentParsingException(xContentLocation, "Failed to update field [" + original.name() + "]", exc);
+            }
+        }
+        return original;
+    }
+
+    private record FieldMapperAndParent(ObjectMapper parent, Mapper mapper) {}
+
+    private FieldMapperAndParent findFieldMapper(ObjectMapper mapper, String fullName) {
+        String[] pathElements = fullName.split("\\.");
+        for (int i = 0; i < pathElements.length - 1; i++) {
+            Mapper next = mapper.getMapper(pathElements[i]);
+            if (next == null || next instanceof ObjectMapper == false) {
+                return null;
+            }
+            mapper = (ObjectMapper) next;
+        }
+        return new FieldMapperAndParent(mapper, mapper.getMapper(pathElements[pathElements.length - 1]));
+    }
+
+    @SuppressWarnings("unchecked")
+    private void parseSingleField(DocumentParserContext context, MapperBuilderContext mapperBuilderContext) throws IOException {
+        XContentParser parser = context.parser();
+        String fieldName = parser.currentName();
+        var res = findFieldMapper(context.root(), fieldName);
+        if (res == null || res.mapper == null || res.mapper instanceof SemanticTextFieldMapper == false) {
+            throw new DocumentParsingException(
+                parser.getTokenLocation(),
+                Strings.format("Field [%s] is not registered as a [%s] field type", fieldName, SemanticTextFieldMapper.CONTENT_TYPE)
+            );
+        }
+        parser.nextToken();
+        failIfTokenIsNot(parser.getTokenLocation(), parser, XContentParser.Token.START_OBJECT);
+        XContentLocation xContentLocation = parser.getTokenLocation();
+
+        Map<String, Object> map = parser.mapOrdered();
+        Map<String, Object> modelSettingsMap = (Map<String, Object>) map.remove(SemanticTextModelSettings.NAME);
+        var modelSettings = SemanticTextModelSettings.parse(
+            XContentHelper.mapToXContentParser(XContentParserConfiguration.EMPTY, modelSettingsMap)
+        );
+        var fieldMapper = updateSemanticTextFieldMapper(
+            context,
+            mapperBuilderContext,
+            (SemanticTextFieldMapper) res.mapper,
+            modelSettings,
+            xContentLocation
+        );
+        XContentParser subParser = new MapXContentParser(
+            NamedXContentRegistry.EMPTY,
+            DeprecationHandler.IGNORE_DEPRECATIONS,
+            map,
+            XContentType.JSON
+        );
+        DocumentParserContext mapContext = context.switchParser(subParser);
+        parseFieldInferenceObject(xContentLocation, subParser, mapContext, fieldMapper.getNestedField());
+    }
+
+    private void parseFieldInferenceObject(
+        XContentLocation xContentLocation,
+        XContentParser parser,
+        DocumentParserContext context,
+        NestedObjectMapper nestedMapper
+    ) throws IOException {
+        parser.nextToken();
+        failIfTokenIsNot(xContentLocation, parser, XContentParser.Token.START_OBJECT);
+        for (XContentParser.Token token = parser.nextToken(); token != XContentParser.Token.END_OBJECT; token = parser.nextToken()) {
+            switch (parser.currentName()) {
+                case RESULTS -> parseResultsList(xContentLocation, parser, context, nestedMapper);
+                default -> throw new DocumentParsingException(xContentLocation, "Unknown field name " + parser.currentName());
+            }
+        }
+    }
+
+    private void parseResultsList(
+        XContentLocation xContentLocation,
+        XContentParser parser,
+        DocumentParserContext context,
+        NestedObjectMapper nestedMapper
+    ) throws IOException {
+        parser.nextToken();
+        failIfTokenIsNot(xContentLocation, parser, XContentParser.Token.START_ARRAY);
+        for (XContentParser.Token token = parser.nextToken(); token != XContentParser.Token.END_ARRAY; token = parser.nextToken()) {
+            DocumentParserContext subContext = context.createNestedContext(nestedMapper);
+            parseResultsObject(xContentLocation, parser, subContext, nestedMapper);
+        }
+    }
+
+    private void parseResultsObject(
+        XContentLocation xContentLocation,
+        XContentParser parser,
+        DocumentParserContext context,
+        NestedObjectMapper nestedMapper
+    ) throws IOException {
+        failIfTokenIsNot(xContentLocation, parser, XContentParser.Token.START_OBJECT);
+        Set<String> visited = new HashSet<>();
+        for (XContentParser.Token token = parser.nextToken(); token != XContentParser.Token.END_OBJECT; token = parser.nextToken()) {
+            failIfTokenIsNot(xContentLocation, parser, XContentParser.Token.FIELD_NAME);
+            visited.add(parser.currentName());
+            FieldMapper fieldMapper = (FieldMapper) nestedMapper.getMapper(parser.currentName());
+            if (fieldMapper == null) {
+                logger.debug("Skipping indexing of unrecognized field name [" + parser.currentName() + "]");
+                advancePastCurrentFieldName(xContentLocation, parser);
+                continue;
+            }
+            parser.nextToken();
+            fieldMapper.parse(context);
+        }
+        if (visited.containsAll(REQUIRED_SUBFIELDS) == false) {
+            Set<String> missingSubfields = REQUIRED_SUBFIELDS.stream()
+                .filter(s -> visited.contains(s) == false)
+                .collect(Collectors.toSet());
+            throw new DocumentParsingException(xContentLocation, "Missing required subfields: " + missingSubfields);
+        }
+    }
+
+    private static void failIfTokenIsNot(XContentLocation xContentLocation, XContentParser parser, XContentParser.Token expected) {
+        if (parser.currentToken() != expected) {
+            throw new DocumentParsingException(xContentLocation, "Expected a " + expected.toString() + ", got " + parser.currentToken());
+        }
+    }
+
+    private static void advancePastCurrentFieldName(XContentLocation xContentLocation, XContentParser parser) throws IOException {
+        assert parser.currentToken() == XContentParser.Token.FIELD_NAME;
+        XContentParser.Token token = parser.nextToken();
+        if (token == XContentParser.Token.START_OBJECT || token == XContentParser.Token.START_ARRAY) {
+            parser.skipChildren();
+        } else if (token.isValue() == false && token != XContentParser.Token.VALUE_NULL) {
+            throw new DocumentParsingException(xContentLocation, "Expected a START_* or VALUE_*, got " + token);
+        }
+    }
+
+    @Override
+    protected String contentType() {
+        return CONTENT_TYPE;
+    }
+
+    @Override
+    public SourceLoader.SyntheticFieldLoader syntheticFieldLoader() {
+        return SourceLoader.SyntheticFieldLoader.NOTHING;
+    }
+
+    public static void applyFieldInference(
+        Map<String, Object> inferenceMap,
+        String field,
+        Model model,
+        ChunkedInferenceServiceResults results
+    ) throws ElasticsearchException {
+        List<Map<String, Object>> chunks = new ArrayList<>();
+        if (results instanceof ChunkedSparseEmbeddingResults textExpansionResults) {
+            for (var chunk : textExpansionResults.getChunkedResults()) {
+                chunks.add(chunk.asMap());
+            }
+        } else if (results instanceof ChunkedTextEmbeddingResults textEmbeddingResults) {
+            for (var chunk : textEmbeddingResults.getChunks()) {
+                chunks.add(chunk.asMap());
+            }
+        } else {
+            throw new ElasticsearchStatusException(
+                "Invalid inference results format for field [{}] with inference id [{}], got {}",
+                RestStatus.BAD_REQUEST,
+                field,
+                model.getInferenceEntityId(),
+                results.getWriteableName()
+            );
+        }
+        Map<String, Object> fieldMap = new LinkedHashMap<>();
+        fieldMap.putAll(new SemanticTextModelSettings(model).asMap());
+        fieldMap.put(InferenceMetadataFieldMapper.RESULTS, chunks);
+        inferenceMap.put(field, fieldMap);
+    }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapper.java
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapper.java deleted file mode 100644 index 2ede5419ab74e..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapper.java +++ /dev/null @@ -1,372 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.mapper; - -import org.apache.lucene.search.Query; -import org.elasticsearch.ElasticsearchException; -import org.elasticsearch.ElasticsearchStatusException; -import org.elasticsearch.common.Strings; -import org.elasticsearch.index.IndexVersion; -import org.elasticsearch.index.mapper.DocumentParserContext; -import org.elasticsearch.index.mapper.DocumentParsingException; -import org.elasticsearch.index.mapper.FieldMapper; -import org.elasticsearch.index.mapper.MappedFieldType; -import org.elasticsearch.index.mapper.Mapper; -import org.elasticsearch.index.mapper.MapperBuilderContext; -import org.elasticsearch.index.mapper.MetadataFieldMapper; -import org.elasticsearch.index.mapper.NestedObjectMapper; -import org.elasticsearch.index.mapper.ObjectMapper; -import org.elasticsearch.index.mapper.SourceLoader; -import org.elasticsearch.index.mapper.SourceValueFetcher; -import org.elasticsearch.index.mapper.TextFieldMapper; -import org.elasticsearch.index.mapper.TextSearchInfo; -import org.elasticsearch.index.mapper.ValueFetcher; -import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; -import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper; -import org.elasticsearch.index.query.SearchExecutionContext; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; -import org.elasticsearch.inference.Model; -import org.elasticsearch.inference.SimilarityMeasure; -import org.elasticsearch.inference.TaskType; -import org.elasticsearch.logging.LogManager; -import org.elasticsearch.logging.Logger; -import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.xcontent.XContentParser; -import org.elasticsearch.xpack.core.inference.results.ChunkedSparseEmbeddingResults; -import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingResults; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashSet; -import java.util.LinkedHashMap; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.stream.Collectors; - -/** - * A mapper for the {@code _semantic_text_inference} field. - *
- *
- * This mapper works in tandem with {@link SemanticTextFieldMapper semantic_text} fields to index inference results. - * The inference results for {@code semantic_text} fields are written to {@code _source} by an upstream process like so: - *
- *
- *
- * {
- *     "_source": {
- *         "my_semantic_text_field": "these are not the droids you're looking for",
- *         "_inference": {
- *             "my_semantic_text_field": [
- *                 {
- *                     "sparse_embedding": {
- *                          "lucas": 0.05212344,
- *                          "ty": 0.041213956,
- *                          "dragon": 0.50991,
- *                          "type": 0.23241979,
- *                          "dr": 1.9312073,
- *                          "##o": 0.2797593
- *                     },
- *                     "text": "these are not the droids you're looking for"
- *                 }
- *             ]
- *         }
- *     }
- * }
- * 
- * - * This mapper parses the contents of the {@code _semantic_text_inference} field and indexes it as if the mapping were configured like so: - *
- *
- *
- * {
- *     "mappings": {
- *         "properties": {
- *             "my_semantic_text_field": {
- *                 "type": "nested",
- *                 "properties": {
- *                     "sparse_embedding": {
- *                         "type": "sparse_vector"
- *                     },
- *                     "text": {
- *                         "type": "text",
- *                         "index": false
- *                     }
- *                 }
- *             }
- *         }
- *     }
- * }
- * 
- */ -public class InferenceResultFieldMapper extends MetadataFieldMapper { - public static final String NAME = "_inference"; - public static final String CONTENT_TYPE = "_inference"; - - public static final String RESULTS = "results"; - public static final String INFERENCE_CHUNKS_RESULTS = "inference"; - public static final String INFERENCE_CHUNKS_TEXT = "text"; - - public static final TypeParser PARSER = new FixedTypeParser(c -> new InferenceResultFieldMapper()); - - private static final Logger logger = LogManager.getLogger(InferenceResultFieldMapper.class); - - private static final Set REQUIRED_SUBFIELDS = Set.of(INFERENCE_CHUNKS_TEXT, INFERENCE_CHUNKS_RESULTS); - - static class SemanticTextInferenceFieldType extends MappedFieldType { - private static final MappedFieldType INSTANCE = new SemanticTextInferenceFieldType(); - - SemanticTextInferenceFieldType() { - super(NAME, true, false, false, TextSearchInfo.NONE, Collections.emptyMap()); - } - - @Override - public String typeName() { - return CONTENT_TYPE; - } - - @Override - public ValueFetcher valueFetcher(SearchExecutionContext context, String format) { - return SourceValueFetcher.identity(name(), context, format); - } - - @Override - public Query termQuery(Object value, SearchExecutionContext context) { - return null; - } - } - - public InferenceResultFieldMapper() { - super(SemanticTextInferenceFieldType.INSTANCE); - } - - @Override - protected void parseCreateField(DocumentParserContext context) throws IOException { - XContentParser parser = context.parser(); - failIfTokenIsNot(parser, XContentParser.Token.START_OBJECT); - - parseAllFields(context); - } - - private static void parseAllFields(DocumentParserContext context) throws IOException { - XContentParser parser = context.parser(); - MapperBuilderContext mapperBuilderContext = MapperBuilderContext.root(false, false); - for (XContentParser.Token token = parser.nextToken(); token != XContentParser.Token.END_OBJECT; token = parser.nextToken()) { - failIfTokenIsNot(parser, XContentParser.Token.FIELD_NAME); - - parseSingleField(context, mapperBuilderContext); - } - } - - private static void parseSingleField(DocumentParserContext context, MapperBuilderContext mapperBuilderContext) throws IOException { - - XContentParser parser = context.parser(); - String fieldName = parser.currentName(); - Mapper mapper = context.getMapper(fieldName); - if (mapper == null || SemanticTextFieldMapper.CONTENT_TYPE.equals(mapper.typeName()) == false) { - throw new DocumentParsingException( - parser.getTokenLocation(), - Strings.format("Field [%s] is not registered as a %s field type", fieldName, SemanticTextFieldMapper.CONTENT_TYPE) - ); - } - parser.nextToken(); - failIfTokenIsNot(parser, XContentParser.Token.START_OBJECT); - parser.nextToken(); - SemanticTextModelSettings modelSettings = SemanticTextModelSettings.parse(parser); - for (XContentParser.Token token = parser.nextToken(); token != XContentParser.Token.END_OBJECT; token = parser.nextToken()) { - failIfTokenIsNot(parser, XContentParser.Token.FIELD_NAME); - - String currentName = parser.currentName(); - if (RESULTS.equals(currentName)) { - NestedObjectMapper nestedObjectMapper = createInferenceResultsObjectMapper( - context, - mapperBuilderContext, - fieldName, - modelSettings - ); - parseFieldInferenceChunks(context, mapperBuilderContext, fieldName, modelSettings, nestedObjectMapper); - } else { - logger.debug("Skipping unrecognized field name [" + currentName + "]"); - advancePastCurrentFieldName(parser); - } - } - } - - private static void 
parseFieldInferenceChunks( - DocumentParserContext context, - MapperBuilderContext mapperBuilderContext, - String fieldName, - SemanticTextModelSettings modelSettings, - NestedObjectMapper nestedObjectMapper - ) throws IOException { - XContentParser parser = context.parser(); - - parser.nextToken(); - failIfTokenIsNot(parser, XContentParser.Token.START_ARRAY); - - for (XContentParser.Token token = parser.nextToken(); token != XContentParser.Token.END_ARRAY; token = parser.nextToken()) { - DocumentParserContext nestedContext = context.createNestedContext(nestedObjectMapper); - parseFieldInferenceChunkElement(nestedContext, nestedObjectMapper, modelSettings); - } - } - - private static void parseFieldInferenceChunkElement( - DocumentParserContext context, - ObjectMapper objectMapper, - SemanticTextModelSettings modelSettings - ) throws IOException { - XContentParser parser = context.parser(); - DocumentParserContext childContext = context.createChildContext(objectMapper); - - failIfTokenIsNot(parser, XContentParser.Token.START_OBJECT); - - Set visitedSubfields = new HashSet<>(); - for (XContentParser.Token token = parser.nextToken(); token != XContentParser.Token.END_OBJECT; token = parser.nextToken()) { - failIfTokenIsNot(parser, XContentParser.Token.FIELD_NAME); - - String currentName = parser.currentName(); - visitedSubfields.add(currentName); - - Mapper childMapper = objectMapper.getMapper(currentName); - if (childMapper == null) { - logger.debug("Skipping indexing of unrecognized field name [" + currentName + "]"); - advancePastCurrentFieldName(parser); - continue; - } - - if (childMapper instanceof FieldMapper fieldMapper) { - parser.nextToken(); - fieldMapper.parse(childContext); - } else { - // This should never happen, but fail parsing if it does so that it's not a silent failure - throw new DocumentParsingException( - parser.getTokenLocation(), - Strings.format("Unhandled mapper type [%s] for field [%s]", childMapper.getClass(), currentName) - ); - } - } - - if (visitedSubfields.containsAll(REQUIRED_SUBFIELDS) == false) { - Set missingSubfields = REQUIRED_SUBFIELDS.stream() - .filter(s -> visitedSubfields.contains(s) == false) - .collect(Collectors.toSet()); - throw new DocumentParsingException(parser.getTokenLocation(), "Missing required subfields: " + missingSubfields); - } - } - - private static NestedObjectMapper createInferenceResultsObjectMapper( - DocumentParserContext context, - MapperBuilderContext mapperBuilderContext, - String fieldName, - SemanticTextModelSettings modelSettings - ) { - IndexVersion indexVersionCreated = context.indexSettings().getIndexVersionCreated(); - FieldMapper.Builder resultsBuilder; - if (modelSettings.taskType() == TaskType.SPARSE_EMBEDDING) { - resultsBuilder = new SparseVectorFieldMapper.Builder(INFERENCE_CHUNKS_RESULTS); - } else if (modelSettings.taskType() == TaskType.TEXT_EMBEDDING) { - DenseVectorFieldMapper.Builder denseVectorMapperBuilder = new DenseVectorFieldMapper.Builder( - INFERENCE_CHUNKS_RESULTS, - indexVersionCreated - ); - SimilarityMeasure similarity = modelSettings.similarity(); - if (similarity != null) { - switch (similarity) { - case COSINE -> denseVectorMapperBuilder.similarity(DenseVectorFieldMapper.VectorSimilarity.COSINE); - case DOT_PRODUCT -> denseVectorMapperBuilder.similarity(DenseVectorFieldMapper.VectorSimilarity.DOT_PRODUCT); - default -> throw new IllegalArgumentException( - "Unknown similarity measure for field [" + fieldName + "] in model settings: " + similarity - ); - } - } - Integer dimensions = 
modelSettings.dimensions(); - if (dimensions == null) { - throw new IllegalArgumentException("Model settings for field [" + fieldName + "] must contain dimensions"); - } - denseVectorMapperBuilder.dimensions(dimensions); - resultsBuilder = denseVectorMapperBuilder; - } else { - throw new IllegalArgumentException("Unknown task type for field [" + fieldName + "]: " + modelSettings.taskType()); - } - - TextFieldMapper.Builder textMapperBuilder = new TextFieldMapper.Builder( - INFERENCE_CHUNKS_TEXT, - indexVersionCreated, - context.indexAnalyzers() - ).index(false).store(false); - - NestedObjectMapper.Builder nestedBuilder = new NestedObjectMapper.Builder( - fieldName, - context.indexSettings().getIndexVersionCreated() - ); - nestedBuilder.add(resultsBuilder).add(textMapperBuilder); - - return nestedBuilder.build(mapperBuilderContext); - } - - private static void advancePastCurrentFieldName(XContentParser parser) throws IOException { - assert parser.currentToken() == XContentParser.Token.FIELD_NAME; - - XContentParser.Token token = parser.nextToken(); - if (token == XContentParser.Token.START_OBJECT || token == XContentParser.Token.START_ARRAY) { - parser.skipChildren(); - } else if (token.isValue() == false && token != XContentParser.Token.VALUE_NULL) { - throw new DocumentParsingException(parser.getTokenLocation(), "Expected a START_* or VALUE_*, got " + token); - } - } - - private static void failIfTokenIsNot(XContentParser parser, XContentParser.Token expected) { - if (parser.currentToken() != expected) { - throw new DocumentParsingException( - parser.getTokenLocation(), - "Expected a " + expected.toString() + ", got " + parser.currentToken() - ); - } - } - - @Override - protected String contentType() { - return CONTENT_TYPE; - } - - @Override - public SourceLoader.SyntheticFieldLoader syntheticFieldLoader() { - return SourceLoader.SyntheticFieldLoader.NOTHING; - } - - public static void applyFieldInference( - Map inferenceMap, - String field, - Model model, - ChunkedInferenceServiceResults results - ) throws ElasticsearchException { - List> chunks = new ArrayList<>(); - if (results instanceof ChunkedSparseEmbeddingResults textExpansionResults) { - for (var chunk : textExpansionResults.getChunkedResults()) { - chunks.add(chunk.asMap()); - } - } else if (results instanceof ChunkedTextEmbeddingResults textEmbeddingResults) { - for (var chunk : textEmbeddingResults.getChunks()) { - chunks.add(chunk.asMap()); - } - } else { - throw new ElasticsearchStatusException( - "Invalid inference results format for field [{}] with inference id [{}], got {}", - RestStatus.BAD_REQUEST, - field, - model.getInferenceEntityId(), - results.getWriteableName() - ); - } - Map fieldMap = new LinkedHashMap<>(); - fieldMap.putAll(new SemanticTextModelSettings(model).asMap()); - fieldMap.put(InferenceResultFieldMapper.RESULTS, chunks); - inferenceMap.put(field, fieldMap); - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java index 4caa3d68ba877..deeea81a46d92 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java @@ -9,30 +9,54 @@ import org.apache.lucene.search.Query; import org.elasticsearch.common.Strings; +import org.elasticsearch.index.IndexVersion; 
+import org.elasticsearch.index.analysis.IndexAnalyzers;
 import org.elasticsearch.index.fielddata.FieldDataContext;
 import org.elasticsearch.index.fielddata.IndexFieldData;
 import org.elasticsearch.index.mapper.DocumentParserContext;
 import org.elasticsearch.index.mapper.FieldMapper;
 import org.elasticsearch.index.mapper.InferenceModelFieldType;
 import org.elasticsearch.index.mapper.MappedFieldType;
+import org.elasticsearch.index.mapper.Mapper;
 import org.elasticsearch.index.mapper.MapperBuilderContext;
+import org.elasticsearch.index.mapper.NestedObjectMapper;
+import org.elasticsearch.index.mapper.ObjectMapper;
 import org.elasticsearch.index.mapper.SimpleMappedFieldType;
 import org.elasticsearch.index.mapper.SourceValueFetcher;
+import org.elasticsearch.index.mapper.TextFieldMapper;
 import org.elasticsearch.index.mapper.TextSearchInfo;
 import org.elasticsearch.index.mapper.ValueFetcher;
+import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
+import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper;
 import org.elasticsearch.index.query.SearchExecutionContext;
+import org.elasticsearch.inference.SimilarityMeasure;
+import org.elasticsearch.logging.LogManager;
+import org.elasticsearch.logging.Logger;
+import org.elasticsearch.xcontent.DeprecationHandler;
+import org.elasticsearch.xcontent.NamedXContentRegistry;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentParser;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xcontent.support.MapXContentParser;
 
 import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
 import java.util.Map;
 
+import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper.INFERENCE_CHUNKS_RESULTS;
+import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper.INFERENCE_CHUNKS_TEXT;
+import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper.RESULTS;
+
 /**
- * A {@link FieldMapper} for semantic text fields. These fields have a model id reference, that is used for performing inference
- * at ingestion and query time.
- * For now, it is compatible with text expansion models only, but will be extended to support dense vector models as well.
+ * A {@link FieldMapper} for semantic text fields.
+ * These fields have a model id reference that is used for performing inference at ingestion and query time.
  * This field mapper performs no indexing, as inference results will be included as a different field in the document source, and will
- * be indexed using {@link InferenceResultFieldMapper}.
+ * be indexed using {@link InferenceMetadataFieldMapper}.
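+ *
+ * A minimal sketch of how this mapper can be constructed programmatically, assuming {@code indexVersionCreated}
+ * and {@code indexAnalyzers} are available from the mapping context (hypothetical field and model names):
+ *
+ * SemanticTextFieldMapper mapper = new SemanticTextFieldMapper.Builder("my_field", indexVersionCreated, indexAnalyzers)
+ *     .setModelId("my_inference_id")
+ *     .build(MapperBuilderContext.root(false, false));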
 */
 public class SemanticTextFieldMapper extends FieldMapper {
+    private static final Logger logger = LogManager.getLogger(SemanticTextFieldMapper.class);
 
     public static final String CONTENT_TYPE = "semantic_text";
 
@@ -40,15 +64,47 @@ private static SemanticTextFieldMapper toType(FieldMapper in) {
         return (SemanticTextFieldMapper) in;
     }
 
-    public static final TypeParser PARSER = new TypeParser((n, c) -> new Builder(n), notInMultiFields(CONTENT_TYPE));
+    public static final TypeParser PARSER = new TypeParser(
+        (n, c) -> new Builder(n, c.indexVersionCreated(), c.getIndexAnalyzers()),
+        notInMultiFields(CONTENT_TYPE)
+    );
+
+    private final IndexVersion indexVersionCreated;
+    private final SemanticTextModelSettings modelSettings;
+    private final IndexAnalyzers indexAnalyzers;
+    private final NestedObjectMapper subMappers;
 
-    private SemanticTextFieldMapper(String simpleName, MappedFieldType mappedFieldType, CopyTo copyTo) {
+    private SemanticTextFieldMapper(
+        String simpleName,
+        MappedFieldType mappedFieldType,
+        CopyTo copyTo,
+        IndexVersion indexVersionCreated,
+        IndexAnalyzers indexAnalyzers,
+        SemanticTextModelSettings modelSettings,
+        NestedObjectMapper subMappers
+    ) {
         super(simpleName, mappedFieldType, MultiFields.empty(), copyTo);
+        this.indexVersionCreated = indexVersionCreated;
+        this.indexAnalyzers = indexAnalyzers;
+        this.modelSettings = modelSettings;
+        this.subMappers = subMappers;
+    }
+
+    @Override
+    public String name() {
+        return super.name();
+    }
+
+    @Override
+    public Iterator<Mapper> iterator() {
+        List<Mapper> subIterators = new ArrayList<>();
+        subIterators.add(subMappers);
+        return subIterators.iterator();
     }
 
     @Override
     public FieldMapper.Builder getMergeBuilder() {
-        return new Builder(simpleName()).init(this);
+        return new Builder(simpleName(), indexVersionCreated, indexAnalyzers).init(this);
     }
 
     @Override
@@ -67,7 +123,17 @@ public SemanticTextFieldType fieldType() {
         return (SemanticTextFieldType) super.fieldType();
     }
 
+    public SemanticTextModelSettings getModelSettings() {
+        return modelSettings;
+    }
+
+    public NestedObjectMapper getNestedField() {
+        return subMappers;
+    }
+
     public static class Builder extends FieldMapper.Builder {
+        private final IndexVersion indexVersionCreated;
+        private final IndexAnalyzers indexAnalyzers;
 
         private final Parameter<String> modelId = Parameter.stringParam("model_id", false, m -> toType(m).fieldType().modelId, null)
             .addValidator(v -> {
@@ -76,25 +142,84 @@ public static class Builder extends FieldMapper.Builder {
             }
         });
 
+        @SuppressWarnings("unchecked")
+        private final Parameter<SemanticTextModelSettings> modelSettings = new Parameter<>(
+            "model_settings",
+            true,
+            () -> null,
+            (name, context, node) -> {
+                if (node == null) {
+                    return null;
+                }
+                try {
+                    Map<String, Object> map = (Map<String, Object>) node;
+                    XContentParser parser = new MapXContentParser(
+                        NamedXContentRegistry.EMPTY,
+                        DeprecationHandler.IGNORE_DEPRECATIONS,
+                        map,
+                        XContentType.JSON
+                    );
+                    return SemanticTextModelSettings.parse(parser);
+                } catch (Exception exc) {
+                    throw new IllegalArgumentException(exc);
+                }
+            },
+            m -> ((SemanticTextFieldMapper) m).modelSettings,
+            XContentBuilder::field,
+            Strings::toString
+        ).acceptsNull().setMergeValidator(SemanticTextModelSettings::checkCompatibility);
+
         private final Parameter<Map<String, String>> meta = Parameter.metaParam();
 
-        public Builder(String name) {
+        public Builder(String name, IndexVersion indexVersionCreated, IndexAnalyzers indexAnalyzers) {
             super(name);
+            this.indexVersionCreated = indexVersionCreated;
+            this.indexAnalyzers = indexAnalyzers;
+        }
+
+        public Builder setModelId(String id) {
+            this.modelId.setValue(id);
+            return this;
+        }
+
+        public Builder setModelSettings(SemanticTextModelSettings value) {
+            this.modelSettings.setValue(value);
+            return this;
+        }
 
         @Override
         protected Parameter<?>[] getParameters() {
-            return new Parameter<?>[] { modelId, meta };
+            return new Parameter<?>[] { modelId, meta, modelSettings };
         }
 
         @Override
         public SemanticTextFieldMapper build(MapperBuilderContext context) {
-            return new SemanticTextFieldMapper(name(), new SemanticTextFieldType(name(), modelId.getValue(), meta.getValue()), copyTo);
+            final String fullName = context.buildFullName(name());
+            NestedObjectMapper.Builder nestedBuilder = new NestedObjectMapper.Builder(RESULTS, indexVersionCreated);
+            nestedBuilder.dynamic(ObjectMapper.Dynamic.FALSE);
+            TextFieldMapper.Builder textMapperBuilder = new TextFieldMapper.Builder(
+                INFERENCE_CHUNKS_TEXT,
+                indexVersionCreated,
+                indexAnalyzers
+            ).index(false).store(false);
+            if (modelSettings.get() != null) {
+                nestedBuilder.add(createInferenceMapperBuilder(INFERENCE_CHUNKS_RESULTS, modelSettings.get(), indexVersionCreated));
+            }
+            nestedBuilder.add(textMapperBuilder);
+            var childContext = context.createChildContext(name(), ObjectMapper.Dynamic.FALSE);
+            return new SemanticTextFieldMapper(
+                name(),
+                new SemanticTextFieldType(fullName, modelId.getValue(), meta.getValue()),
+                copyTo,
+                indexVersionCreated,
+                indexAnalyzers,
+                modelSettings.getValue(),
+                nestedBuilder.build(childContext)
+            );
         }
     }
 
     public static class SemanticTextFieldType extends SimpleMappedFieldType implements InferenceModelFieldType {
-
         private final String modelId;
 
         public SemanticTextFieldType(String name, String modelId, Map<String, String> meta) {
@@ -127,4 +252,54 @@ public IndexFieldData.Builder fielddataBuilder(FieldDataContext fieldDataContext
             throw new IllegalArgumentException("[semantic_text] fields do not support sorting, scripting or aggregating");
         }
     }
+
+    private static Mapper.Builder createInferenceMapperBuilder(
+        String fieldName,
+        SemanticTextModelSettings modelSettings,
+        IndexVersion indexVersionCreated
+    ) {
+        return switch (modelSettings.taskType()) {
+            case SPARSE_EMBEDDING -> new SparseVectorFieldMapper.Builder(INFERENCE_CHUNKS_RESULTS);
+            case TEXT_EMBEDDING -> {
+                DenseVectorFieldMapper.Builder denseVectorMapperBuilder = new DenseVectorFieldMapper.Builder(
+                    INFERENCE_CHUNKS_RESULTS,
+                    indexVersionCreated
+                );
+                SimilarityMeasure similarity = modelSettings.similarity();
+                if (similarity != null) {
+                    switch (similarity) {
+                        case COSINE -> denseVectorMapperBuilder.similarity(DenseVectorFieldMapper.VectorSimilarity.COSINE);
+                        case DOT_PRODUCT -> denseVectorMapperBuilder.similarity(DenseVectorFieldMapper.VectorSimilarity.DOT_PRODUCT);
+                        default -> throw new IllegalArgumentException(
+                            "Unknown similarity measure for field [" + fieldName + "] in model settings: " + similarity
+                        );
+                    }
+                }
+                Integer dimensions = modelSettings.dimensions();
+                denseVectorMapperBuilder.dimensions(dimensions);
+                yield denseVectorMapperBuilder;
+            }
+            default -> throw new IllegalArgumentException(
+                "Invalid [task_type] for [" + fieldName + "] in model settings: " + modelSettings.taskType().name()
+            );
+        };
+    }
+
+    @Override
+    protected void checkIncomingMergeType(FieldMapper mergeWith) {
+        if (mergeWith instanceof SemanticTextFieldMapper other) {
+            if (other.modelSettings != null && other.modelSettings.inferenceId().equals(other.fieldType().getInferenceModel()) == false) {
+                throw new IllegalArgumentException(
+                    "mapper ["
+                        + name()
+                        + "] refers to different model ids ["
+                        + other.modelSettings.inferenceId()
+                        + "] and ["
+                        + other.fieldType().getInferenceModel()
+                        + "]"
+                );
+            }
+        }
+        super.checkIncomingMergeType(mergeWith);
+    }
 }
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextModelSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextModelSettings.java
index 1b6bb22c0d6b5..8b49e420f16a6 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextModelSettings.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextModelSettings.java
@@ -7,11 +7,14 @@
 
 package org.elasticsearch.xpack.inference.mapper;
 
+import org.elasticsearch.index.mapper.FieldMapper;
 import org.elasticsearch.inference.Model;
 import org.elasticsearch.inference.SimilarityMeasure;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.xcontent.ConstructingObjectParser;
 import org.elasticsearch.xcontent.ParseField;
+import org.elasticsearch.xcontent.ToXContentObject;
+import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xcontent.XContentParser;
 
 import java.io.IOException;
+import java.util.Objects;
@@ -22,7 +25,7 @@
 /**
  * Serialization class for specifying the settings of a model from semantic_text inference to field mapper.
  */
-public class SemanticTextModelSettings {
+public class SemanticTextModelSettings implements ToXContentObject {
 
     public static final String NAME = "model_settings";
     public static final ParseField TASK_TYPE_FIELD = new ParseField("task_type");
@@ -98,4 +101,44 @@ public Integer dimensions() {
     public SimilarityMeasure similarity() {
         return similarity;
     }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        builder.field(TASK_TYPE_FIELD.getPreferredName(), taskType.toString());
+        builder.field(INFERENCE_ID_FIELD.getPreferredName(), inferenceId);
+        if (dimensions != null) {
+            builder.field(DIMENSIONS_FIELD.getPreferredName(), dimensions);
+        }
+        if (similarity != null) {
+            builder.field(SIMILARITY_FIELD.getPreferredName(), similarity);
+        }
+        return builder.endObject();
+    }
+
+    public static boolean checkCompatibility(
+        SemanticTextModelSettings original,
+        SemanticTextModelSettings another,
+        FieldMapper.Conflicts conflicts
+    ) {
+        if (original == null) {
+            return true;
+        }
+        if (another == null) {
+            conflicts.addConflict("model_settings", "configured", "missing");
+            return false;
+        }
+        if (original.inferenceId.equals(another.inferenceId) == false) {
+            conflicts.addConflict(INFERENCE_ID_FIELD.getPreferredName(), original.inferenceId, another.inferenceId);
+        }
+        if (original.taskType != another.taskType()) {
+            conflicts.addConflict(TASK_TYPE_FIELD.getPreferredName(), original.taskType.toString(), another.taskType().toString());
+        }
+        if (Objects.equals(original.dimensions, another.dimensions) == false) {
+            conflicts.addConflict(DIMENSIONS_FIELD.getPreferredName(), String.valueOf(original.dimensions), String.valueOf(another.dimensions));
+        }
+        if (original.similarity != another.similarity) {
+            conflicts.addConflict(SIMILARITY_FIELD.getPreferredName(), String.valueOf(original.similarity), String.valueOf(another.similarity));
+        }
+        return conflicts.hasConflicts() == false;
+    }
 }
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java
index 7f3ffbe596543..a7af1443dc0ca 100644
---
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java @@ -31,7 +31,7 @@ import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.json.JsonXContent; import org.elasticsearch.xpack.core.inference.results.ChunkedSparseEmbeddingResults; -import org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapper; +import org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper; import org.elasticsearch.xpack.inference.model.TestModel; import org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.junit.After; @@ -51,8 +51,8 @@ import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.awaitLatch; -import static org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapperTests.randomSparseEmbeddings; -import static org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapperTests.randomTextEmbeddings; +import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapperTests.randomSparseEmbeddings; +import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapperTests.randomTextEmbeddings; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.instanceOf; import static org.mockito.Mockito.any; @@ -291,11 +291,11 @@ private static BulkItemRequest[] randomBulkItemRequest( throw new AssertionError("Unknown task type " + taskType.name()); } model.putResult(text, results); - InferenceResultFieldMapper.applyFieldInference(inferenceResultsMap, field, model, results); + InferenceMetadataFieldMapper.applyFieldInference(inferenceResultsMap, field, model, results); } } Map expectedDocMap = new LinkedHashMap<>(docMap); - expectedDocMap.put(InferenceResultFieldMapper.NAME, inferenceResultsMap); + expectedDocMap.put(InferenceMetadataFieldMapper.NAME, inferenceResultsMap); return new BulkItemRequest[] { new BulkItemRequest(id, new IndexRequest("index").source(docMap)), new BulkItemRequest(id, new IndexRequest("index").source(expectedDocMap)) }; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapperTests.java similarity index 66% rename from x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapperTests.java rename to x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapperTests.java index b5d75b528c6ab..b212ce6a269ef 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/InferenceResultFieldMapperTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapperTests.java @@ -7,6 +7,8 @@ package org.elasticsearch.xpack.inference.mapper; +import org.apache.lucene.document.FeatureField; +import org.apache.lucene.index.IndexableField; import org.apache.lucene.index.Term; import org.apache.lucene.search.BooleanClause; import org.apache.lucene.search.BooleanQuery; @@ -51,26 +53,28 @@ import java.util.Collection; import java.util.HashMap; import java.util.HashSet; +import java.util.LinkedHashMap; import 
java.util.List; import java.util.Map; import java.util.Set; import java.util.function.Consumer; -import static org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapper.INFERENCE_CHUNKS_RESULTS; -import static org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapper.INFERENCE_CHUNKS_TEXT; -import static org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapper.RESULTS; +import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper.INFERENCE_CHUNKS_RESULTS; +import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper.INFERENCE_CHUNKS_TEXT; +import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper.RESULTS; import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; -public class InferenceResultFieldMapperTests extends MetadataMapperTestCase { - private record SemanticTextInferenceResults(String fieldName, ChunkedInferenceServiceResults results, List text) {} +public class InferenceMetadataFieldMapperTests extends MetadataMapperTestCase { + private record SemanticTextInferenceResults(String fieldName, Model model, ChunkedInferenceServiceResults results, List text) {} - private record VisitedChildDocInfo(String path, int numChunks) {} + private record VisitedChildDocInfo(String path) {} private record SparseVectorSubfieldOptions(boolean include, boolean includeEmbedding, boolean includeIsTruncated) {} @Override protected String fieldName() { - return InferenceResultFieldMapper.NAME; + return InferenceMetadataFieldMapper.NAME; } @Override @@ -94,109 +98,127 @@ protected Collection getPlugins() { } public void testSuccessfulParse() throws IOException { - final String fieldName1 = randomAlphaOfLengthBetween(5, 15); - final String fieldName2 = randomAlphaOfLengthBetween(5, 15); - - DocumentMapper documentMapper = createDocumentMapper(mapping(b -> { - addSemanticTextMapping(b, fieldName1, randomAlphaOfLength(8)); - addSemanticTextMapping(b, fieldName2, randomAlphaOfLength(8)); - })); - ParsedDocument doc = documentMapper.parse( - source( - b -> addSemanticTextInferenceResults( - b, - List.of( - randomSemanticTextInferenceResults(fieldName1, List.of("a b", "c")), - randomSemanticTextInferenceResults(fieldName2, List.of("d e f")) + for (int depth = 1; depth < 4; depth++) { + final String fieldName1 = randomFieldName(depth); + final String fieldName2 = randomFieldName(depth + 1); + + Model model1 = randomModel(); + Model model2 = randomModel(); + XContentBuilder mapping = mapping(b -> { + addSemanticTextMapping(b, fieldName1, model1.getInferenceEntityId()); + addSemanticTextMapping(b, fieldName2, model2.getInferenceEntityId()); + }); + + MapperService mapperService = createMapperService(mapping); + DocumentMapper documentMapper = mapperService.documentMapper(); + ParsedDocument doc = documentMapper.parse( + source( + b -> addSemanticTextInferenceResults( + b, + List.of( + randomSemanticTextInferenceResults(fieldName1, model1, List.of("a b", "c")), + randomSemanticTextInferenceResults(fieldName2, model2, List.of("d e f")) + ) ) ) - ) - ); - - Set visitedChildDocs = new HashSet<>(); - Set expectedVisitedChildDocs = Set.of( - new VisitedChildDocInfo(fieldName1, 2), - new VisitedChildDocInfo(fieldName1, 1), - new VisitedChildDocInfo(fieldName2, 3) - ); - - List luceneDocs = doc.docs(); - assertEquals(4, luceneDocs.size()); - assertValidChildDoc(luceneDocs.get(0), doc.rootDoc(), visitedChildDocs); - assertValidChildDoc(luceneDocs.get(1), doc.rootDoc(), 
-        assertValidChildDoc(luceneDocs.get(1), doc.rootDoc(), visitedChildDocs);
-        assertValidChildDoc(luceneDocs.get(2), doc.rootDoc(), visitedChildDocs);
-        assertEquals(doc.rootDoc(), luceneDocs.get(3));
-        assertNull(luceneDocs.get(3).getParent());
-        assertEquals(expectedVisitedChildDocs, visitedChildDocs);
-
-        MapperService nestedMapperService = createMapperService(mapping(b -> {
-            addInferenceResultsNestedMapping(b, fieldName1);
-            addInferenceResultsNestedMapping(b, fieldName2);
-        }));
-        withLuceneIndex(nestedMapperService, iw -> iw.addDocuments(doc.docs()), reader -> {
-            NestedDocuments nested = new NestedDocuments(
-                nestedMapperService.mappingLookup(),
-                QueryBitSetProducer::new,
-                IndexVersion.current()
-            );
-            LeafNestedDocuments leaf = nested.getLeafNestedDocuments(reader.leaves().get(0));
-
-            Set<SearchHit.NestedIdentity> visitedNestedIdentities = new HashSet<>();
-            Set<SearchHit.NestedIdentity> expectedVisitedNestedIdentities = Set.of(
-                new SearchHit.NestedIdentity(fieldName1, 0, null),
-                new SearchHit.NestedIdentity(fieldName1, 1, null),
-                new SearchHit.NestedIdentity(fieldName2, 0, null)
-            );
-            assertChildLeafNestedDocument(leaf, 0, 3, visitedNestedIdentities);
-            assertChildLeafNestedDocument(leaf, 1, 3, visitedNestedIdentities);
-            assertChildLeafNestedDocument(leaf, 2, 3, visitedNestedIdentities);
-            assertEquals(expectedVisitedNestedIdentities, visitedNestedIdentities);
-
-            assertNull(leaf.advance(3));
-            assertEquals(3, leaf.doc());
-            assertEquals(3, leaf.rootDoc());
-            assertNull(leaf.nestedIdentity());
-
-            IndexSearcher searcher = newSearcher(reader);
-            {
-                TopDocs topDocs = searcher.search(
-                    generateNestedTermSparseVectorQuery(nestedMapperService.mappingLookup().nestedLookup(), fieldName1, List.of("a")),
-                    10
-                );
-                assertEquals(1, topDocs.totalHits.value);
-                assertEquals(3, topDocs.scoreDocs[0].doc);
-            }
-            {
-                TopDocs topDocs = searcher.search(
-                    generateNestedTermSparseVectorQuery(nestedMapperService.mappingLookup().nestedLookup(), fieldName1, List.of("a", "b")),
-                    10
-                );
-                assertEquals(1, topDocs.totalHits.value);
-                assertEquals(3, topDocs.scoreDocs[0].doc);
-            }
-            {
-                TopDocs topDocs = searcher.search(
-                    generateNestedTermSparseVectorQuery(nestedMapperService.mappingLookup().nestedLookup(), fieldName2, List.of("d")),
-                    10
-                );
-                assertEquals(1, topDocs.totalHits.value);
-                assertEquals(3, topDocs.scoreDocs[0].doc);
-            }
-            {
-                TopDocs topDocs = searcher.search(
-                    generateNestedTermSparseVectorQuery(nestedMapperService.mappingLookup().nestedLookup(), fieldName2, List.of("z")),
-                    10
-                );
-                assertEquals(0, topDocs.totalHits.value);
-            }
-        });
+        for (int depth = 1; depth < 4; depth++) {
+            final String fieldName1 = randomFieldName(depth);
+            final String fieldName2 = randomFieldName(depth + 1);
+
+            Model model1 = randomModel();
+            Model model2 = randomModel();
+            XContentBuilder mapping = mapping(b -> {
+                addSemanticTextMapping(b, fieldName1, model1.getInferenceEntityId());
+                addSemanticTextMapping(b, fieldName2, model2.getInferenceEntityId());
+            });
+
+            MapperService mapperService = createMapperService(mapping);
+            DocumentMapper documentMapper = mapperService.documentMapper();
+            ParsedDocument doc = documentMapper.parse(
+                source(
+                    b -> addSemanticTextInferenceResults(
+                        b,
+                        List.of(
+                            randomSemanticTextInferenceResults(fieldName1, model1, List.of("a b", "c")),
+                            randomSemanticTextInferenceResults(fieldName2, model2, List.of("d e f"))
+                        )
+                    )
+                )
+            );
+
+            List<LuceneDocument> luceneDocs = doc.docs();
+            assertEquals(4, luceneDocs.size());
+            for (int i = 0; i < 3; i++) {
+                assertEquals(doc.rootDoc(), luceneDocs.get(i).getParent());
+            }
+            // nested docs are in reversed order
+            assertSparseFeatures(luceneDocs.get(0), fieldName1 + ".results.inference", 2);
+            assertSparseFeatures(luceneDocs.get(1), fieldName1 + ".results.inference", 1);
+            assertSparseFeatures(luceneDocs.get(2), fieldName2 + ".results.inference", 3);
+            assertEquals(doc.rootDoc(), luceneDocs.get(3));
+            assertNull(luceneDocs.get(3).getParent());
+
+            withLuceneIndex(mapperService, iw -> iw.addDocuments(doc.docs()), reader -> {
+                NestedDocuments nested = new NestedDocuments(
+                    mapperService.mappingLookup(),
+                    QueryBitSetProducer::new,
+                    IndexVersion.current()
+                );
+                LeafNestedDocuments leaf = nested.getLeafNestedDocuments(reader.leaves().get(0));
+
+                Set<SearchHit.NestedIdentity> visitedNestedIdentities = new HashSet<>();
+                Set<SearchHit.NestedIdentity> expectedVisitedNestedIdentities = Set.of(
+                    new SearchHit.NestedIdentity(fieldName1 + "." + RESULTS, 0, null),
+                    new SearchHit.NestedIdentity(fieldName1 + "." + RESULTS, 1, null),
+                    new SearchHit.NestedIdentity(fieldName2 + "." + RESULTS, 0, null)
+                );
+
+                assertChildLeafNestedDocument(leaf, 0, 3, visitedNestedIdentities);
+                assertChildLeafNestedDocument(leaf, 1, 3, visitedNestedIdentities);
+                assertChildLeafNestedDocument(leaf, 2, 3, visitedNestedIdentities);
+                assertEquals(expectedVisitedNestedIdentities, visitedNestedIdentities);
+
+                assertNull(leaf.advance(3));
+                assertEquals(3, leaf.doc());
+                assertEquals(3, leaf.rootDoc());
+                assertNull(leaf.nestedIdentity());
+
+                IndexSearcher searcher = newSearcher(reader);
+                {
+                    TopDocs topDocs = searcher.search(
+                        generateNestedTermSparseVectorQuery(
+                            mapperService.mappingLookup().nestedLookup(),
+                            fieldName1 + "." + RESULTS,
+                            List.of("a")
+                        ),
+                        10
+                    );
+                    assertEquals(1, topDocs.totalHits.value);
+                    assertEquals(3, topDocs.scoreDocs[0].doc);
+                }
+                {
+                    TopDocs topDocs = searcher.search(
+                        generateNestedTermSparseVectorQuery(
+                            mapperService.mappingLookup().nestedLookup(),
+                            fieldName1 + "." + RESULTS,
+                            List.of("a", "b")
+                        ),
+                        10
+                    );
+                    assertEquals(1, topDocs.totalHits.value);
+                    assertEquals(3, topDocs.scoreDocs[0].doc);
+                }
+                {
+                    TopDocs topDocs = searcher.search(
+                        generateNestedTermSparseVectorQuery(
+                            mapperService.mappingLookup().nestedLookup(),
+                            fieldName2 + "." + RESULTS,
+                            List.of("d")
+                        ),
+                        10
+                    );
+                    assertEquals(1, topDocs.totalHits.value);
+                    assertEquals(3, topDocs.scoreDocs[0].doc);
+                }
+                {
+                    TopDocs topDocs = searcher.search(
+                        generateNestedTermSparseVectorQuery(
+                            mapperService.mappingLookup().nestedLookup(),
+                            fieldName2 + "." + RESULTS,
+                            List.of("z")
+                        ),
+                        10
+                    );
+                    assertEquals(0, topDocs.totalHits.value);
+                }
+            });
+        }
     }
 
     public void testMissingSubfields() throws IOException {
         final String fieldName = randomAlphaOfLengthBetween(5, 15);
+        final Model model = randomModel();
 
-        DocumentMapper documentMapper = createDocumentMapper(mapping(b -> addSemanticTextMapping(b, fieldName, randomAlphaOfLength(8))));
+        DocumentMapper documentMapper = createDocumentMapper(
+            mapping(b -> addSemanticTextMapping(b, fieldName, model.getInferenceEntityId()))
+        );
 
         {
             DocumentParsingException ex = expectThrows(
@@ -206,7 +228,7 @@
                 source(
                     b -> addSemanticTextInferenceResults(
                         b,
-                        List.of(randomSemanticTextInferenceResults(fieldName, List.of("a b"))),
+                        List.of(randomSemanticTextInferenceResults(fieldName, model, List.of("a b"))),
                         new SparseVectorSubfieldOptions(false, true, true),
                         true,
                         Map.of()
@@ -224,7 +246,7 @@
                 source(
                     b -> addSemanticTextInferenceResults(
                         b,
-                        List.of(randomSemanticTextInferenceResults(fieldName, List.of("a b"))),
+                        List.of(randomSemanticTextInferenceResults(fieldName, model, List.of("a b"))),
                         new SparseVectorSubfieldOptions(true, true, true),
                         false,
                         Map.of()
@@ -242,7 +264,7 @@
                 source(
                     b -> addSemanticTextInferenceResults(
                         b,
-                        List.of(randomSemanticTextInferenceResults(fieldName, List.of("a b"))),
+                        List.of(randomSemanticTextInferenceResults(fieldName, model, List.of("a b"))),
                         new SparseVectorSubfieldOptions(false, true, true),
                         false,
                         Map.of()
@@ -259,15 +281,18 @@
     public void testExtraSubfields() throws IOException {
         final String fieldName = randomAlphaOfLengthBetween(5, 15);
+        final Model model = randomModel();
         final List<SemanticTextInferenceResults> semanticTextInferenceResultsList = List.of(
-            randomSemanticTextInferenceResults(fieldName, List.of("a b"))
+            randomSemanticTextInferenceResults(fieldName, model, List.of("a b"))
         );
 
-        DocumentMapper documentMapper = createDocumentMapper(mapping(b -> addSemanticTextMapping(b, fieldName, randomAlphaOfLength(8))));
+        DocumentMapper documentMapper = createDocumentMapper(
+            mapping(b -> addSemanticTextMapping(b, fieldName, model.getInferenceEntityId()))
+        );
 
         Consumer<ParsedDocument> checkParsedDocument = d -> {
             Set<VisitedChildDocInfo> visitedChildDocs = new HashSet<>();
-            Set<VisitedChildDocInfo> expectedVisitedChildDocs = Set.of(new VisitedChildDocInfo(fieldName, 2));
+            Set<VisitedChildDocInfo> expectedVisitedChildDocs = Set.of(new VisitedChildDocInfo(fieldName + "." + RESULTS));
 
             List<LuceneDocument> luceneDocs = d.docs();
             assertEquals(2, luceneDocs.size());
@@ -358,13 +383,18 @@ public void testMissingSemanticTextMapping() throws IOException {
             DocumentParsingException.class,
             DocumentParsingException.class,
             () -> documentMapper.parse(
-                source(b -> addSemanticTextInferenceResults(b, List.of(randomSemanticTextInferenceResults(fieldName, List.of("a b")))))
+                source(
+                    b -> addSemanticTextInferenceResults(
+                        b,
+                        List.of(randomSemanticTextInferenceResults(fieldName, randomModel(), List.of("a b")))
+                    )
+                )
             )
         );
         assertThat(
             ex.getMessage(),
             containsString(
-                Strings.format("Field [%s] is not registered as a %s field type", fieldName, SemanticTextFieldMapper.CONTENT_TYPE)
+                Strings.format("Field [%s] is not registered as a [%s] field type", fieldName, SemanticTextFieldMapper.CONTENT_TYPE)
             )
         );
     }
@@ -400,8 +430,12 @@ public static ChunkedSparseEmbeddingResults randomSparseEmbeddings(List<String> chunks) {
         return new ChunkedSparseEmbeddingResults(chunks);
     }
 
-    private static SemanticTextInferenceResults randomSemanticTextInferenceResults(String semanticTextFieldName, List<String> chunks) {
-        return new SemanticTextInferenceResults(semanticTextFieldName, randomSparseEmbeddings(chunks), chunks);
+    private static SemanticTextInferenceResults randomSemanticTextInferenceResults(
+        String semanticTextFieldName,
+        Model model,
+        List<String> chunks
+    ) {
+        return new SemanticTextInferenceResults(semanticTextFieldName, model, randomSparseEmbeddings(chunks), chunks);
     }
 
     private static void addSemanticTextInferenceResults(
@@ -425,12 +459,12 @@
         boolean includeTextSubfield,
         Map<String, Object> extraSubfields
     ) throws IOException {
-        Map<String, Object> inferenceResultsMap = new HashMap<>();
+        Map<String, Object> inferenceResultsMap = new LinkedHashMap<>();
         for (SemanticTextInferenceResults semanticTextInferenceResult : semanticTextInferenceResults) {
-            InferenceResultFieldMapper.applyFieldInference(
+            InferenceMetadataFieldMapper.applyFieldInference(
                 inferenceResultsMap,
                 semanticTextInferenceResult.fieldName,
-                randomModel(),
+                semanticTextInferenceResult.model,
                 semanticTextInferenceResult.results
             );
             Map<String, Object> optionsMap = (Map<String, Object>) inferenceResultsMap.get(semanticTextInferenceResult.fieldName);
@@ -445,7 +479,18 @@
                 entry.putAll(extraSubfields);
             }
         }
-        sourceBuilder.field(InferenceResultFieldMapper.NAME, inferenceResultsMap);
+        sourceBuilder.field(InferenceMetadataFieldMapper.NAME, inferenceResultsMap);
+    }
+
+    private String randomFieldName(int numLevel) {
+        StringBuilder builder = new StringBuilder();
+        for (int i = 0; i < numLevel; i++) {
+            if (i > 0) {
+                builder.append('.');
+            }
+            builder.append(randomAlphaOfLengthBetween(5, 15));
+        }
+        return builder.toString();
     }
 
     private static Model randomModel() {
@@ -461,29 +506,6 @@ private static Model randomModel() {
         );
     }
 
-    private static void addInferenceResultsNestedMapping(XContentBuilder mappingBuilder, String semanticTextFieldName) throws IOException {
-        mappingBuilder.startObject(semanticTextFieldName);
-        {
-            mappingBuilder.field("type", "nested");
-            mappingBuilder.startObject("properties");
-            {
-                mappingBuilder.startObject(INFERENCE_CHUNKS_RESULTS);
-                {
-                    mappingBuilder.field("type", "sparse_vector");
-                }
-                mappingBuilder.endObject();
-                mappingBuilder.startObject(INFERENCE_CHUNKS_TEXT);
-                {
-                    mappingBuilder.field("type", "text");
-                    mappingBuilder.field("index", false);
-                }
-                mappingBuilder.endObject();
-            }
-            mappingBuilder.endObject();
-        }
-        mappingBuilder.endObject();
-    }
-
     private static Query generateNestedTermSparseVectorQuery(NestedLookup nestedLookup, String path, List<String> tokens) {
         NestedObjectMapper mapper = nestedLookup.getNestedMappers().get(path);
         assertNotNull(mapper);
@@ -503,12 +525,10 @@ private static Query generateNestedTermSparseVectorQuery(NestedLook
     private static void assertValidChildDoc(
         LuceneDocument childDoc,
         LuceneDocument expectedParent,
-        Set<VisitedChildDocInfo> visitedChildDocs
+        Collection<VisitedChildDocInfo> visitedChildDocs
     ) {
         assertEquals(expectedParent, childDoc.getParent());
-        visitedChildDocs.add(
-            new VisitedChildDocInfo(childDoc.getPath(), childDoc.getFields(childDoc.getPath() + "." + INFERENCE_CHUNKS_RESULTS).size())
-        );
+        visitedChildDocs.add(new VisitedChildDocInfo(childDoc.getPath()));
     }
 
     private static void assertChildLeafNestedDocument(
@@ -524,4 +544,15 @@ private static void assertChildLeafNestedDocument(
         assertNotNull(leaf.nestedIdentity());
         visitedNestedIdentities.add(leaf.nestedIdentity());
     }
+
+    private static void assertSparseFeatures(LuceneDocument doc, String fieldName, int expectedCount) {
+        int count = 0;
+        for (IndexableField field : doc.getFields()) {
+            if (field instanceof FeatureField featureField) {
+                assertThat(featureField.name(), equalTo(fieldName));
+                ++count;
+            }
+        }
+        assertThat(count, equalTo(expectedCount));
+    }
 }
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java
index a3a705c9cc902..274ef346e27e4 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java
@@ -12,6 +12,7 @@
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.index.mapper.DocumentMapper;
 import org.elasticsearch.index.mapper.MappedFieldType;
+import org.elasticsearch.index.mapper.Mapper;
 import org.elasticsearch.index.mapper.MapperParsingException;
 import org.elasticsearch.index.mapper.MapperService;
 import org.elasticsearch.index.mapper.MapperTestCase;
@@ -82,6 +83,11 @@ protected void minimalMapping(XContentBuilder b) throws IOException {
         b.field("type", "semantic_text").field("model_id", "test_model");
     }
 
+    @Override
+    protected String minimalIsInvalidRoutingPathErrorMessage(Mapper mapper) {
+        return "cannot have nested fields when index is in [index.mode=time_series]";
+    }
+
     @Override
     protected Object getSampleValueForDocument() {
         return "value";
diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml
index 2c69f49218091..6744b04014446 100644
--- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml
+++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapper.yml
@@ -86,7 +86,7 @@ setup:
           _inference:
             dense_field:
               model_settings:
-                inference_id: sparse-inference-id
+                inference_id: dense-inference-id
                 task_type: text_embedding
                 dimensions: 5
                 similarity: cosine
@@ -144,7 +144,7 @@ setup:
           _inference:
             dense_field:
               model_settings:
-                inference_id: sparse-inference-id
+                inference_id: dense-inference-id
                 task_type: text_embedding
               results:
                 - text: "inference test"