diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java
index 08d11f7bd41f2..ca1d8f75542b4 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java
@@ -129,8 +129,7 @@ protected void merge(FieldMapper mergeWith, Conflicts conflicts, MapperMergeCont
         var semanticMergeWith = (SemanticTextFieldMapper) mergeWith;
         var context = mapperMergeContext.createChildContext(mergeWith.simpleName(), ObjectMapper.Dynamic.FALSE);
         var inferenceField = inferenceFieldBuilder.apply(context.getMapperBuilderContext());
-        var childContext = context.createChildContext(inferenceField.simpleName(), ObjectMapper.Dynamic.FALSE);
-        var mergedInferenceField = inferenceField.merge(semanticMergeWith.fieldType().getInferenceField(), childContext);
+        var mergedInferenceField = inferenceField.merge(semanticMergeWith.fieldType().getInferenceField(), context);
         inferenceFieldBuilder = c -> mergedInferenceField;
     }
 
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java
index a6f0fa83eab37..9392a3184e3ac 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java
@@ -19,12 +19,16 @@
 import org.apache.lucene.search.join.BitSetProducer;
 import org.apache.lucene.search.join.QueryBitSetProducer;
 import org.apache.lucene.search.join.ScoreMode;
+import org.elasticsearch.action.admin.indices.mapping.put.PutMappingRequest;
 import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.bytes.BytesArray;
+import org.elasticsearch.common.compress.CompressedXContent;
 import org.elasticsearch.common.lucene.search.Queries;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.index.IndexVersion;
 import org.elasticsearch.index.mapper.DocumentMapper;
 import org.elasticsearch.index.mapper.DocumentParsingException;
+import org.elasticsearch.index.mapper.FieldMapper;
 import org.elasticsearch.index.mapper.KeywordFieldMapper;
 import org.elasticsearch.index.mapper.LuceneDocument;
 import org.elasticsearch.index.mapper.MappedFieldType;
@@ -35,6 +39,7 @@
 import org.elasticsearch.index.mapper.NestedLookup;
 import org.elasticsearch.index.mapper.NestedObjectMapper;
 import org.elasticsearch.index.mapper.ParsedDocument;
+import org.elasticsearch.index.mapper.SourceToParse;
 import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
 import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper;
 import org.elasticsearch.index.search.ESToParentBlockJoinQuery;
@@ -168,6 +173,45 @@ public void testUpdatesToInferenceIdNotSupported() throws IOException {
         assertThat(e.getMessage(), containsString("Cannot update parameter [inference_id] from [test_model] to [another_model]"));
     }
 
+    public void testDynamicUpdate() throws IOException {
+        MapperService mapperService = createMapperService(mapping(b -> {}));
+        mapperService.merge(
+            "_doc",
+            new CompressedXContent(
+                Strings.toString(PutMappingRequest.simpleMapping("semantic", "type=semantic_text,inference_id=test_service"))
+            ),
+            MapperService.MergeReason.MAPPING_UPDATE
+        );
+        String source = """
+            {
+              "semantic": {
+                "inference": {
+                  "inference_id": "test_service",
+                  "model_settings": {
+                    "task_type": "SPARSE_EMBEDDING"
+                  },
+                  "chunks": [
+                    {
+                      "embeddings": {
+                        "feature_0": 1
+                      },
+                      "text": "feature_0"
+                    }
+                  ]
+                }
+              }
+            }
+            """;
+        SourceToParse sourceToParse = new SourceToParse("test", new BytesArray(source), XContentType.JSON);
+        ParsedDocument parsedDocument = mapperService.documentMapper().parse(sourceToParse);
+        mapperService.merge(
+            "_doc",
+            parsedDocument.dynamicMappingsUpdate().toCompressedXContent(),
+            MapperService.MergeReason.MAPPING_UPDATE
+        );
+        assertSemanticTextField(mapperService, "semantic", true);
+    }
+
     public void testUpdateModelSettings() throws IOException {
         for (int depth = 1; depth < 5; depth++) {
             String fieldName = randomFieldName(depth);
@@ -270,6 +314,7 @@ static void assertSemanticTextField(MapperService mapperService, String fieldNam
             .getNestedMappers()
             .get(getChunksFieldName(fieldName));
         assertThat(chunksMapper, equalTo(semanticFieldMapper.fieldType().getChunksField()));
+        assertThat(chunksMapper.name(), equalTo(getChunksFieldName(fieldName)));
         Mapper textMapper = chunksMapper.getMapper(CHUNKED_TEXT_FIELD.getPreferredName());
         assertNotNull(textMapper);
         assertThat(textMapper, instanceOf(KeywordFieldMapper.class));
@@ -278,11 +323,15 @@ static void assertSemanticTextField(MapperService mapperService, String fieldNam
         assertFalse(textFieldMapper.fieldType().hasDocValues());
         if (expectedModelSettings) {
             assertNotNull(semanticFieldMapper.fieldType().getModelSettings());
-            Mapper inferenceMapper = chunksMapper.getMapper(CHUNKED_EMBEDDINGS_FIELD.getPreferredName());
-            assertNotNull(inferenceMapper);
+            Mapper embeddingsMapper = chunksMapper.getMapper(CHUNKED_EMBEDDINGS_FIELD.getPreferredName());
+            assertNotNull(embeddingsMapper);
+            assertThat(embeddingsMapper, instanceOf(FieldMapper.class));
+            FieldMapper embeddingsFieldMapper = (FieldMapper) embeddingsMapper;
+            assertTrue(embeddingsFieldMapper.fieldType() == mapperService.mappingLookup().getFieldType(getEmbeddingsFieldName(fieldName)));
+            assertThat(embeddingsMapper.name(), equalTo(getEmbeddingsFieldName(fieldName)));
             switch (semanticFieldMapper.fieldType().getModelSettings().taskType()) {
-                case SPARSE_EMBEDDING -> assertThat(inferenceMapper, instanceOf(SparseVectorFieldMapper.class));
-                case TEXT_EMBEDDING -> assertThat(inferenceMapper, instanceOf(DenseVectorFieldMapper.class));
+                case SPARSE_EMBEDDING -> assertThat(embeddingsMapper, instanceOf(SparseVectorFieldMapper.class));
+                case TEXT_EMBEDDING -> assertThat(embeddingsMapper, instanceOf(DenseVectorFieldMapper.class));
                 default -> throw new AssertionError("Invalid task type");
             }
         } else {