Skip to content

Commit

Permalink
Fix the mapper builder context when updating the semantic text field …
Browse files Browse the repository at this point in the history
…definition
  • Loading branch information
jimczi committed Mar 21, 2024
1 parent b3fb5d3 commit 8ddc37f
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedHashMap;
Expand Down Expand Up @@ -158,14 +159,13 @@ public InferenceMetadataFieldMapper() {
protected void parseCreateField(DocumentParserContext context) throws IOException {
XContentParser parser = context.parser();
failIfTokenIsNot(parser.getTokenLocation(), parser, XContentParser.Token.START_OBJECT);
MapperBuilderContext mapperBuilderContext = MapperBuilderContext.root(false, false);
boolean origWithLeafObject = context.path().isWithinLeafObject();
try {
// make sure that we don't expand dots in field names while parsing
context.path().setWithinLeafObject(true);
for (XContentParser.Token token = parser.nextToken(); token != XContentParser.Token.END_OBJECT; token = parser.nextToken()) {
failIfTokenIsNot(parser.getTokenLocation(), parser, XContentParser.Token.FIELD_NAME);
parseSingleField(context, mapperBuilderContext);
parseSingleField(context);
}
} finally {
context.path().setWithinLeafObject(origWithLeafObject);
Expand All @@ -174,59 +174,56 @@ protected void parseCreateField(DocumentParserContext context) throws IOExceptio

private NestedObjectMapper updateSemanticTextFieldMapper(
DocumentParserContext docContext,
MapperBuilderContext mapperBuilderContext,
ObjectMapper parent,
SemanticTextFieldMapper original,
String inferenceId,
SemanticTextModelSettings modelSettings,
SemanticTextMapperContext semanticFieldContext,
String newInferenceId,
SemanticTextModelSettings newModelSettings,
XContentLocation xContentLocation
) {
if (inferenceId.equals(original.fieldType().getInferenceId()) == false) {
final String fullFieldName = semanticFieldContext.mapper.fieldType().name();
final String inferenceId = semanticFieldContext.mapper.fieldType().getInferenceId();
if (newInferenceId.equals(inferenceId) == false) {
throw new DocumentParsingException(
xContentLocation,
Strings.format(
"The configured %s [%s] for field [%s] doesn't match the %s [%s] reported in the document.",
INFERENCE_ID,
inferenceId,
original.name(),
fullFieldName,
INFERENCE_ID,
original.fieldType().getInferenceId()
newInferenceId
)
);
}
if (modelSettings.taskType() == TaskType.TEXT_EMBEDDING && modelSettings.dimensions() == null) {
if (newModelSettings.taskType() == TaskType.TEXT_EMBEDDING && newModelSettings.dimensions() == null) {
throw new DocumentParsingException(
xContentLocation,
"Model settings for field [" + original.name() + "] must contain dimensions"
"Model settings for field [" + fullFieldName + "] must contain dimensions"
);
}
if (original.getModelSettings() == null) {
if (parent != docContext.root()) {
mapperBuilderContext = mapperBuilderContext.createChildContext(parent.name(), ObjectMapper.Dynamic.FALSE);
}
if (semanticFieldContext.mapper.getModelSettings() == null) {
SemanticTextFieldMapper newMapper = new SemanticTextFieldMapper.Builder(
original.simpleName(),
semanticFieldContext.mapper.simpleName(),
docContext.indexSettings().getIndexVersionCreated()
).setInferenceId(original.fieldType().getInferenceId()).setModelSettings(modelSettings).build(mapperBuilderContext);
).setInferenceId(newInferenceId).setModelSettings(newModelSettings).build(semanticFieldContext.context);
docContext.addDynamicMapper(newMapper);
return newMapper.getSubMappers();
} else {
SemanticTextFieldMapper.Conflicts conflicts = new Conflicts(original.name());
SemanticTextFieldMapper.canMergeModelSettings(original.getModelSettings(), modelSettings, conflicts);
SemanticTextFieldMapper.Conflicts conflicts = new Conflicts(fullFieldName);
SemanticTextFieldMapper.canMergeModelSettings(semanticFieldContext.mapper.getModelSettings(), newModelSettings, conflicts);
try {
conflicts.check();
} catch (Exception exc) {
throw new DocumentParsingException(xContentLocation, "Incompatible model_settings", exc);
}
}
return original.getSubMappers();
return semanticFieldContext.mapper.getSubMappers();
}

private void parseSingleField(DocumentParserContext context, MapperBuilderContext mapperBuilderContext) throws IOException {
private void parseSingleField(DocumentParserContext context) throws IOException {
XContentParser parser = context.parser();
String fieldName = parser.currentName();
var res = findMapper(context.mappingLookup().getMapping().getRoot(), fieldName);
if (res == null || res.mapper instanceof SemanticTextFieldMapper == false) {
SemanticTextMapperContext builderContext = createSemanticFieldContext(context, fieldName);
if (builderContext == null) {
throw new DocumentParsingException(
parser.getTokenLocation(),
Strings.format("Field [%s] is not registered as a [%s] field type", fieldName, SemanticTextFieldMapper.CONTENT_TYPE)
Expand Down Expand Up @@ -276,15 +273,7 @@ private void parseSingleField(DocumentParserContext context, MapperBuilderContex
);
}

var nestedObjectMapper = updateSemanticTextFieldMapper(
context,
mapperBuilderContext,
res.parent,
(SemanticTextFieldMapper) res.mapper,
inferenceId,
modelSettings,
xContentLocation
);
var nestedObjectMapper = updateSemanticTextFieldMapper(context, builderContext, inferenceId, modelSettings, xContentLocation);

// we know the model settings, so we can (re) parse the results array now
XContentParser subParser = new MapXContentParser(
Expand Down Expand Up @@ -419,17 +408,40 @@ public static void applyFieldInference(
inferenceMap.put(field, fieldMap);
}

record MapperAndParent(ObjectMapper parent, Mapper mapper) {}
record SemanticTextMapperContext(MapperBuilderContext context, SemanticTextFieldMapper mapper) {}

/**
* Returns the {@link SemanticTextFieldMapper} associated with the provided {@code fullName}
* and the {@link MapperBuilderContext} that was used to build it.
* If the field is not found or is of the wrong type, this method returns {@code null}.
*/
static SemanticTextMapperContext createSemanticFieldContext(DocumentParserContext docContext, String fullName) {
ObjectMapper rootMapper = docContext.mappingLookup().getMapping().getRoot();
return createSemanticFieldContext(MapperBuilderContext.root(false, false), rootMapper, fullName, fullName.split("\\."));
}

static MapperAndParent findMapper(ObjectMapper mapper, String fullPath) {
String[] pathElements = fullPath.split("\\.");
for (int i = 0; i < pathElements.length - 1; i++) {
Mapper next = mapper.getMapper(pathElements[i]);
if (next == null || next instanceof ObjectMapper == false) {
static SemanticTextMapperContext createSemanticFieldContext(
MapperBuilderContext mapperContext,
ObjectMapper objectMapper,
String fullName,
String[] paths
) {
Mapper mapper = objectMapper.getMapper(paths[0]);
if (mapper instanceof ObjectMapper newObjectMapper) {
mapperContext = mapperContext.createChildContext(paths[0], ObjectMapper.Dynamic.FALSE);
return createSemanticFieldContext(mapperContext, newObjectMapper, fullName, Arrays.copyOfRange(paths, 1, paths.length));
} else if (mapper instanceof SemanticTextFieldMapper semanticMapper) {
return new SemanticTextMapperContext(mapperContext, semanticMapper);
} else {
if (mapper == null || paths.length == 1) {
return null;
}
mapper = (ObjectMapper) next;
// check if the semantic field is defined within a multi-field
Mapper fieldMapper = objectMapper.getMapper(String.join(".", Arrays.asList(paths)));
if (fieldMapper instanceof SemanticTextFieldMapper semanticMapper) {
return new SemanticTextMapperContext(mapperContext, semanticMapper);
}
}
return new MapperAndParent(mapper, mapper.getMapper(pathElements[pathElements.length - 1]));
return null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import org.elasticsearch.index.mapper.KeywordFieldMapper;
import org.elasticsearch.index.mapper.MappedFieldType;
import org.elasticsearch.index.mapper.Mapper;
import org.elasticsearch.index.mapper.MapperBuilderContext;
import org.elasticsearch.index.mapper.MapperParsingException;
import org.elasticsearch.index.mapper.MapperService;
import org.elasticsearch.index.mapper.MapperTestCase;
Expand All @@ -31,7 +32,7 @@
import java.util.List;

import static java.util.Collections.singletonList;
import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper.findMapper;
import static org.elasticsearch.xpack.inference.mapper.InferenceMetadataFieldMapper.createSemanticFieldContext;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.instanceOf;
Expand Down Expand Up @@ -219,7 +220,12 @@ public void testUpdateModelSettings() throws IOException {
}

static void assertSemanticTextField(MapperService mapperService, String fieldName, boolean expectedModelSettings) {
var res = findMapper(mapperService.mappingLookup().getMapping().getRoot(), fieldName);
InferenceMetadataFieldMapper.SemanticTextMapperContext res = createSemanticFieldContext(
MapperBuilderContext.root(false, false),
mapperService.mappingLookup().getMapping().getRoot(),
fieldName,
fieldName.split("\\.")
);
Mapper mapper = res.mapper();
assertNotNull(mapper);
assertThat(mapper, instanceOf(SemanticTextFieldMapper.class));
Expand Down

0 comments on commit 8ddc37f

Please sign in to comment.