Fix redundant path in semantic text field merging
Fix the merging of the object field within the semantic_text mapper: the merge context must be set at the parent level (it was previously set at the object/child level before merging).
jimczi committed Apr 16, 2024
1 parent fe1fca0 commit 5440916
Showing 2 changed files with 54 additions and 6 deletions.
SemanticTextFieldMapper.java

@@ -129,8 +129,7 @@ protected void merge(FieldMapper mergeWith, Conflicts conflicts, MapperMergeCont
     var semanticMergeWith = (SemanticTextFieldMapper) mergeWith;
     var context = mapperMergeContext.createChildContext(mergeWith.simpleName(), ObjectMapper.Dynamic.FALSE);
     var inferenceField = inferenceFieldBuilder.apply(context.getMapperBuilderContext());
-    var childContext = context.createChildContext(inferenceField.simpleName(), ObjectMapper.Dynamic.FALSE);
-    var mergedInferenceField = inferenceField.merge(semanticMergeWith.fieldType().getInferenceField(), childContext);
+    var mergedInferenceField = inferenceField.merge(semanticMergeWith.fieldType().getInferenceField(), context);
     inferenceFieldBuilder = c -> mergedInferenceField;
 }
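
Why merging with the parent context matters: each child merge context scopes mapper names one segment deeper, and the merged object mapper already contributes its own segment. Creating a second child context for the inference field before merging it therefore duplicated that segment in every merged sub-mapper's full path. Below is a minimal, self-contained sketch of that failure mode; MergeContextSketch and its path handling are hypothetical simplifications for illustration, not the Elasticsearch MapperMergeContext API.

// Hypothetical model of how nested merge contexts build dotted mapper paths;
// an illustration of the redundant-path bug, not Elasticsearch code.
final class MergeContextSketch {
    private final String path; // dotted path of the current scope; "" at the root

    MergeContextSketch(String path) {
        this.path = path;
    }

    // Each child context scopes names one segment deeper.
    MergeContextSketch createChildContext(String name) {
        return new MergeContextSketch(path.isEmpty() ? name : path + "." + name);
    }

    // Full path of a mapper built under this scope.
    String fullPath(String relativeName) {
        return path.isEmpty() ? relativeName : path + "." + relativeName;
    }

    public static void main(String[] args) {
        MergeContextSketch root = new MergeContextSketch("");
        // Context scoped to the semantic_text field itself (the parent level).
        MergeContextSketch fieldContext = root.createChildContext("semantic");

        // Before the fix: the merge was handed a context additionally scoped to
        // the inference object, yet the merged object still contributes its own
        // name segment, so the segment appears twice.
        MergeContextSketch childContext = fieldContext.createChildContext("inference");
        System.out.println(childContext.fullPath("inference.chunks"));
        // -> semantic.inference.inference.chunks (redundant path)

        // After the fix: merging with the parent-level context yields the
        // expected path.
        System.out.println(fieldContext.fullPath("inference.chunks"));
        // -> semantic.inference.chunks
    }
}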

SemanticTextFieldMapperTests.java

@@ -19,12 +19,16 @@
 import org.apache.lucene.search.join.BitSetProducer;
 import org.apache.lucene.search.join.QueryBitSetProducer;
 import org.apache.lucene.search.join.ScoreMode;
+import org.elasticsearch.action.admin.indices.mapping.put.PutMappingRequest;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.bytes.BytesArray;
 import org.elasticsearch.common.compress.CompressedXContent;
 import org.elasticsearch.common.lucene.search.Queries;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.index.IndexVersion;
 import org.elasticsearch.index.mapper.DocumentMapper;
 import org.elasticsearch.index.mapper.DocumentParsingException;
+import org.elasticsearch.index.mapper.FieldMapper;
 import org.elasticsearch.index.mapper.KeywordFieldMapper;
 import org.elasticsearch.index.mapper.LuceneDocument;
 import org.elasticsearch.index.mapper.MappedFieldType;
@@ -35,6 +39,7 @@
 import org.elasticsearch.index.mapper.NestedLookup;
 import org.elasticsearch.index.mapper.NestedObjectMapper;
 import org.elasticsearch.index.mapper.ParsedDocument;
+import org.elasticsearch.index.mapper.SourceToParse;
 import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
 import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper;
 import org.elasticsearch.index.search.ESToParentBlockJoinQuery;
@@ -168,6 +173,45 @@ public void testUpdatesToInferenceIdNotSupported() throws IOException {
         assertThat(e.getMessage(), containsString("Cannot update parameter [inference_id] from [test_model] to [another_model]"));
     }

+    public void testDynamicUpdate() throws IOException {
+        MapperService mapperService = createMapperService(mapping(b -> {}));
+        mapperService.merge(
+            "_doc",
+            new CompressedXContent(
+                Strings.toString(PutMappingRequest.simpleMapping("semantic", "type=semantic_text,inference_id=test_service"))
+            ),
+            MapperService.MergeReason.MAPPING_UPDATE
+        );
+        String source = """
+            {
+              "semantic": {
+                "inference": {
+                  "inference_id": "test_service",
+                  "model_settings": {
+                    "task_type": "SPARSE_EMBEDDING"
+                  },
+                  "chunks": [
+                    {
+                      "embeddings": {
+                        "feature_0": 1
+                      },
+                      "text": "feature_0"
+                    }
+                  ]
+                }
+              }
+            }
+            """;
+        SourceToParse sourceToParse = new SourceToParse("test", new BytesArray(source), XContentType.JSON);
+        ParsedDocument parsedDocument = mapperService.documentMapper().parse(sourceToParse);
+        mapperService.merge(
+            "_doc",
+            parsedDocument.dynamicMappingsUpdate().toCompressedXContent(),
+            MapperService.MergeReason.MAPPING_UPDATE
+        );
+        assertSemanticTextField(mapperService, "semantic", true);
+    }
+
     public void testUpdateModelSettings() throws IOException {
         for (int depth = 1; depth < 5; depth++) {
             String fieldName = randomFieldName(depth);
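
The dynamic-update test above drives the fixed code path end to end: merging the mapping update generated during document parsing back into an existing semantic_text field goes through the merge method fixed in the first file, and the name assertions added to assertSemanticTextField below verify that the chunks and embeddings mappers end up at their expected, non-redundant full paths.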
@@ -270,6 +314,7 @@ static void assertSemanticTextField(MapperService mapperService, String fieldNam
             .getNestedMappers()
             .get(getChunksFieldName(fieldName));
         assertThat(chunksMapper, equalTo(semanticFieldMapper.fieldType().getChunksField()));
+        assertThat(chunksMapper.name(), equalTo(getChunksFieldName(fieldName)));
         Mapper textMapper = chunksMapper.getMapper(CHUNKED_TEXT_FIELD.getPreferredName());
         assertNotNull(textMapper);
         assertThat(textMapper, instanceOf(KeywordFieldMapper.class));
@@ -278,11 +323,15 @@
     assertFalse(textFieldMapper.fieldType().hasDocValues());
     if (expectedModelSettings) {
         assertNotNull(semanticFieldMapper.fieldType().getModelSettings());
-        Mapper inferenceMapper = chunksMapper.getMapper(CHUNKED_EMBEDDINGS_FIELD.getPreferredName());
-        assertNotNull(inferenceMapper);
+        Mapper embeddingsMapper = chunksMapper.getMapper(CHUNKED_EMBEDDINGS_FIELD.getPreferredName());
+        assertNotNull(embeddingsMapper);
+        assertThat(embeddingsMapper, instanceOf(FieldMapper.class));
+        FieldMapper embeddingsFieldMapper = (FieldMapper) embeddingsMapper;
+        assertTrue(embeddingsFieldMapper.fieldType() == mapperService.mappingLookup().getFieldType(getEmbeddingsFieldName(fieldName)));
+        assertThat(embeddingsMapper.name(), equalTo(getEmbeddingsFieldName(fieldName)));
         switch (semanticFieldMapper.fieldType().getModelSettings().taskType()) {
-            case SPARSE_EMBEDDING -> assertThat(inferenceMapper, instanceOf(SparseVectorFieldMapper.class));
-            case TEXT_EMBEDDING -> assertThat(inferenceMapper, instanceOf(DenseVectorFieldMapper.class));
+            case SPARSE_EMBEDDING -> assertThat(embeddingsMapper, instanceOf(SparseVectorFieldMapper.class));
+            case TEXT_EMBEDDING -> assertThat(embeddingsMapper, instanceOf(DenseVectorFieldMapper.class));
             default -> throw new AssertionError("Invalid task type");
         }
     } else {
