diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapper.java index c11e8f8b82bf0..20315cb43e2a0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/InferenceMetadataFieldMapper.java @@ -43,6 +43,7 @@ import java.io.IOException; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.HashSet; import java.util.LinkedHashMap; @@ -158,14 +159,13 @@ public InferenceMetadataFieldMapper() { protected void parseCreateField(DocumentParserContext context) throws IOException { XContentParser parser = context.parser(); failIfTokenIsNot(parser.getTokenLocation(), parser, XContentParser.Token.START_OBJECT); - MapperBuilderContext mapperBuilderContext = MapperBuilderContext.root(false, false); boolean origWithLeafObject = context.path().isWithinLeafObject(); try { // make sure that we don't expand dots in field names while parsing context.path().setWithinLeafObject(true); for (XContentParser.Token token = parser.nextToken(); token != XContentParser.Token.END_OBJECT; token = parser.nextToken()) { failIfTokenIsNot(parser.getTokenLocation(), parser, XContentParser.Token.FIELD_NAME); - parseSingleField(context, mapperBuilderContext); + parseSingleField(context); } } finally { context.path().setWithinLeafObject(origWithLeafObject); @@ -174,59 +174,56 @@ protected void parseCreateField(DocumentParserContext context) throws IOExceptio private NestedObjectMapper updateSemanticTextFieldMapper( DocumentParserContext docContext, - MapperBuilderContext mapperBuilderContext, - ObjectMapper parent, - SemanticTextFieldMapper original, - String inferenceId, - SemanticTextModelSettings modelSettings, + SemanticTextMapperContext semanticFieldContext, + String newInferenceId, + SemanticTextModelSettings newModelSettings, XContentLocation xContentLocation ) { - if (inferenceId.equals(original.fieldType().getInferenceId()) == false) { + final String fullFieldName = semanticFieldContext.mapper.fieldType().name(); + final String inferenceId = semanticFieldContext.mapper.fieldType().getInferenceId(); + if (newInferenceId.equals(inferenceId) == false) { throw new DocumentParsingException( xContentLocation, Strings.format( "The configured %s [%s] for field [%s] doesn't match the %s [%s] reported in the document.", INFERENCE_ID, inferenceId, - original.name(), + fullFieldName, INFERENCE_ID, - original.fieldType().getInferenceId() + newInferenceId ) ); } - if (modelSettings.taskType() == TaskType.TEXT_EMBEDDING && modelSettings.dimensions() == null) { + if (newModelSettings.taskType() == TaskType.TEXT_EMBEDDING && newModelSettings.dimensions() == null) { throw new DocumentParsingException( xContentLocation, - "Model settings for field [" + original.name() + "] must contain dimensions" + "Model settings for field [" + fullFieldName + "] must contain dimensions" ); } - if (original.getModelSettings() == null) { - if (parent != docContext.root()) { - mapperBuilderContext = mapperBuilderContext.createChildContext(parent.name(), ObjectMapper.Dynamic.FALSE); - } + if (semanticFieldContext.mapper.getModelSettings() == null) { SemanticTextFieldMapper newMapper = new SemanticTextFieldMapper.Builder( - original.simpleName(), + semanticFieldContext.mapper.simpleName(), docContext.indexSettings().getIndexVersionCreated() - ).setInferenceId(original.fieldType().getInferenceId()).setModelSettings(modelSettings).build(mapperBuilderContext); + ).setInferenceId(newInferenceId).setModelSettings(newModelSettings).build(semanticFieldContext.context); docContext.addDynamicMapper(newMapper); return newMapper.getSubMappers(); } else { - SemanticTextFieldMapper.Conflicts conflicts = new Conflicts(original.name()); - SemanticTextFieldMapper.canMergeModelSettings(original.getModelSettings(), modelSettings, conflicts); + SemanticTextFieldMapper.Conflicts conflicts = new Conflicts(fullFieldName); + SemanticTextFieldMapper.canMergeModelSettings(semanticFieldContext.mapper.getModelSettings(), newModelSettings, conflicts); try { conflicts.check(); } catch (Exception exc) { throw new DocumentParsingException(xContentLocation, "Incompatible model_settings", exc); } } - return original.getSubMappers(); + return semanticFieldContext.mapper.getSubMappers(); } - private void parseSingleField(DocumentParserContext context, MapperBuilderContext mapperBuilderContext) throws IOException { + private void parseSingleField(DocumentParserContext context) throws IOException { XContentParser parser = context.parser(); String fieldName = parser.currentName(); - var res = findMapper(context.mappingLookup().getMapping().getRoot(), fieldName); - if (res == null || res.mapper instanceof SemanticTextFieldMapper == false) { + SemanticTextMapperContext builderContext = createSemanticFieldContext(context, fieldName); + if (builderContext == null) { throw new DocumentParsingException( parser.getTokenLocation(), Strings.format("Field [%s] is not registered as a [%s] field type", fieldName, SemanticTextFieldMapper.CONTENT_TYPE) @@ -276,15 +273,7 @@ private void parseSingleField(DocumentParserContext context, MapperBuilderContex ); } - var nestedObjectMapper = updateSemanticTextFieldMapper( - context, - mapperBuilderContext, - res.parent, - (SemanticTextFieldMapper) res.mapper, - inferenceId, - modelSettings, - xContentLocation - ); + var nestedObjectMapper = updateSemanticTextFieldMapper(context, builderContext, inferenceId, modelSettings, xContentLocation); // we know the model settings, so we can (re) parse the results array now XContentParser subParser = new MapXContentParser( @@ -419,17 +408,40 @@ public static void applyFieldInference( inferenceMap.put(field, fieldMap); } - record MapperAndParent(ObjectMapper parent, Mapper mapper) {} + record SemanticTextMapperContext(MapperBuilderContext context, SemanticTextFieldMapper mapper) {} + + /** + * Returns the {@link SemanticTextFieldMapper} associated with the provided {@code fullName} + * and the {@link MapperBuilderContext} that was used to build it. + * If the field is not found or is of the wrong type, this method returns {@code null}. + */ + static SemanticTextMapperContext createSemanticFieldContext(DocumentParserContext docContext, String fullName) { + ObjectMapper rootMapper = docContext.mappingLookup().getMapping().getRoot(); + return createSemanticFieldContext(MapperBuilderContext.root(false, false), rootMapper, fullName, fullName.split("\\.")); + } - static MapperAndParent findMapper(ObjectMapper mapper, String fullPath) { - String[] pathElements = fullPath.split("\\."); - for (int i = 0; i < pathElements.length - 1; i++) { - Mapper next = mapper.getMapper(pathElements[i]); - if (next == null || next instanceof ObjectMapper == false) { + static SemanticTextMapperContext createSemanticFieldContext( + MapperBuilderContext mapperContext, + ObjectMapper objectMapper, + String fullName, + String[] paths + ) { + Mapper mapper = objectMapper.getMapper(paths[0]); + if (mapper instanceof ObjectMapper newObjectMapper) { + mapperContext = mapperContext.createChildContext(paths[0], ObjectMapper.Dynamic.FALSE); + return createSemanticFieldContext(mapperContext, newObjectMapper, fullName, Arrays.copyOfRange(paths, 1, paths.length)); + } else if (mapper instanceof SemanticTextFieldMapper semanticMapper) { + return new SemanticTextMapperContext(mapperContext, semanticMapper); + } else { + if (mapper == null || paths.length == 1) { return null; } - mapper = (ObjectMapper) next; + // check if the semantic field is defined within a multi-field + Mapper fieldMapper = objectMapper.getMapper(String.join(".", Arrays.asList(paths))); + if (fieldMapper instanceof SemanticTextFieldMapper semanticMapper) { + return new SemanticTextMapperContext(mapperContext, semanticMapper); + } } - return new MapperAndParent(mapper, mapper.getMapper(pathElements[pathElements.length - 1])); + return null; } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java index 551b5f73fe27e..e9b5a788256d0 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java @@ -14,6 +14,7 @@ import org.elasticsearch.index.mapper.KeywordFieldMapper; import org.elasticsearch.index.mapper.MappedFieldType; import org.elasticsearch.index.mapper.Mapper; +import org.elasticsearch.index.mapper.MapperBuilderContext; import org.elasticsearch.index.mapper.MapperParsingException; import org.elasticsearch.index.mapper.MapperService; import org.elasticsearch.index.mapper.MapperTestCase; @@ -31,7 +32,7 @@ import java.util.List; import static java.util.Collections.singletonList; -import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper.findMapper; +import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper.createSemanticFieldContext; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.instanceOf; @@ -219,7 +220,12 @@ public void testUpdateModelSettings() throws IOException { } static void assertSemanticTextField(MapperService mapperService, String fieldName, boolean expectedModelSettings) { - var res = findMapper(mapperService.mappingLookup().getMapping().getRoot(), fieldName); + InferenceMetadataFieldMapper.SemanticTextMapperContext res = createSemanticFieldContext( + MapperBuilderContext.root(false, false), + mapperService.mappingLookup().getMapping().getRoot(), + fieldName, + fieldName.split("\\.") + ); Mapper mapper = res.mapper(); assertNotNull(mapper); assertThat(mapper, instanceOf(SemanticTextFieldMapper.class));