Skip to content

Commit

Permalink
First working version of parsing nested inference results
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosdelest committed Nov 23, 2023
1 parent be33d25 commit b1d5bd0
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import org.elasticsearch.index.analysis.CharFilterFactory;
import org.elasticsearch.index.analysis.TokenizerFactory;
import org.elasticsearch.index.mapper.Mapper;
import org.elasticsearch.index.mapper.MetadataFieldMapper;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.indices.AssociatedIndexDescriptor;
import org.elasticsearch.indices.SystemIndexDescriptor;
Expand Down Expand Up @@ -364,6 +365,7 @@
import org.elasticsearch.xpack.ml.job.snapshot.upgrader.SnapshotUpgradeTaskExecutor;
import org.elasticsearch.xpack.ml.job.task.OpenJobPersistentTasksExecutor;
import org.elasticsearch.xpack.ml.mapper.SemanticTextFieldMapper;
import org.elasticsearch.xpack.ml.mapper.SemanticTextInferenceResultFieldMapper;
import org.elasticsearch.xpack.ml.notifications.AnomalyDetectionAuditor;
import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
import org.elasticsearch.xpack.ml.notifications.InferenceAuditor;
Expand Down Expand Up @@ -2284,13 +2286,13 @@ public Map<String, Mapper.TypeParser> getMappers() {
);
}

// @Override
// public Map<String, MetadataFieldMapper.TypeParser> getMetadataMappers() {
// return Map.of(
// SemanticTextInferenceResultFieldMapper.CONTENT_TYPE,
// SemanticTextInferenceResultFieldMapper.PARSER
// );
// }
/**
 * Registers the metadata field mapper that stores semantic-text inference results,
 * keyed by its content type, so the document parser can route the
 * {@code _semantic_text_inference} metadata object to it.
 *
 * @return an immutable single-entry map from content type to the mapper's parser
 */
@Override
public Map<String, MetadataFieldMapper.TypeParser> getMetadataMappers() {
    final Map<String, MetadataFieldMapper.TypeParser> metadataMappers = Map.of(
        SemanticTextInferenceResultFieldMapper.CONTENT_TYPE,
        SemanticTextInferenceResultFieldMapper.PARSER
    );
    return metadataMappers;
}

@Override
public Optional<Pipeline> getIngestPipeline(IndexMetadata indexMetadata, Processor.Parameters parameters) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,21 @@
import org.elasticsearch.index.mapper.Mapper;
import org.elasticsearch.index.mapper.MapperBuilderContext;
import org.elasticsearch.index.mapper.MetadataFieldMapper;
import org.elasticsearch.index.mapper.NestedObjectMapper;
import org.elasticsearch.index.mapper.SourceLoader;
import org.elasticsearch.index.mapper.SourceValueFetcher;
import org.elasticsearch.index.mapper.TextFieldMapper;
import org.elasticsearch.index.mapper.TextSearchInfo;
import org.elasticsearch.index.mapper.ValueFetcher;
import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper;
import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper.SparseVectorFieldType;
import org.elasticsearch.index.query.SearchExecutionContext;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xcontent.XContentParser;

import java.io.IOException;
import java.util.Collections;
import java.util.Map;
import java.util.HashSet;
import java.util.Set;

public class SemanticTextInferenceResultFieldMapper extends MetadataFieldMapper {

Expand All @@ -47,7 +49,6 @@ private static SemanticTextInferenceResultFieldMapper toType(FieldMapper in) {
public static class SemanticTextInferenceFieldType extends MappedFieldType {

public static final MappedFieldType INSTANCE = new SemanticTextInferenceFieldType();
private SparseVectorFieldType sparseVectorFieldType;

public SemanticTextInferenceFieldType() {
super(NAME, true, false, false, TextSearchInfo.NONE, Collections.emptyMap());
Expand All @@ -65,7 +66,7 @@ public ValueFetcher valueFetcher(SearchExecutionContext context, String format)

// NOTE(review): scraped diff view — the two return statements below are the removed
// and added lines of the same commit, shown without +/- markers; only one exists in
// each version of the file.
@Override
public Query termQuery(Object value, SearchExecutionContext context) {
// old (removed in this commit): delegated to a sparse_vector field type; the backing
// field declaration was also deleted (see removed sparseVectorFieldType field above).
return sparseVectorFieldType.termQuery(value, context);
// new (added in this commit): this metadata field no longer supports term queries directly.
return null;
}
}

Expand All @@ -78,40 +79,82 @@ public void parse(DocumentParserContext context) throws IOException {

// NOTE(review): this span is scraped from a GitHub diff view, so removed (old) and
// added (new) lines are interleaved with no +/- markers; as plain text it is NOT
// compilable. Comments below flag the duplicated old-vs-new lines — confirm against
// commit b1d5bd0 before relying on them.

// The _semantic_text_inference metadata value must be a JSON object keyed by field name.
if (context.parser().currentToken() != XContentParser.Token.START_OBJECT) {
throw new IllegalArgumentException(
// old (removed) error message:
"[_semantic_text_inference] fields must be a json object, expected a START_OBJECT but got: "
// new (added) error message:
"[_semantic_text] produced inference must be a json object, expected a START_OBJECT but got: "
+ context.parser().currentToken()
);
}

// Builder context rooted under this metadata field's name (NAME).
MapperBuilderContext mapperBuilderContext = MapperBuilderContext.root(false, false).createChildContext(NAME);

// TODO Can we validate that semantic text fields have actual text values?
// Outer loop: one entry per semantic_text field that produced inference results.
for (XContentParser.Token token = context.parser().nextToken(); token != XContentParser.Token.END_OBJECT; token = context.parser()
.nextToken()) {
if (token != XContentParser.Token.FIELD_NAME) {
// old (removed) message:
throw new IllegalArgumentException("[semantic_text] fields expect an object with field names, found " + token);
// new (added) message:
throw new IllegalArgumentException("[semantic_text] produced inference expect an object with field names, found " + token);
}

String fieldName = context.parser().currentName();

Mapper mapper = context.getMapper(fieldName);
// old (removed) branch: unmapped field names were tolerated, consumed as a raw map,
// and skipped — superseded by the combined null check in the new condition below.
if (mapper == null) {
// Not a field we have mapped? Must be model output, skip it
context.parser().nextToken();
context.path().setWithinLeafObject(true);
Map<String, Object> fieldMap = context.parser().map();
context.path().setWithinLeafObject(false);
continue;
}
// old (removed) condition:
if (SemanticTextFieldMapper.CONTENT_TYPE.equals(mapper.typeName()) == false) {
// new (added) condition: unmapped fields now fail fast instead of being skipped.
if ((mapper == null) || SemanticTextFieldMapper.CONTENT_TYPE.equals(mapper.typeName()) == false) {
throw new IllegalArgumentException(
"Found [" + fieldName + "] in inference values, but it is not registered as a semantic_text field type"
);
}

// old (removed) handling: parsed the value directly as a flat sparse_vector field.
context.parser().nextToken();
SparseVectorFieldMapper sparseVectorFieldMapper = new SparseVectorFieldMapper.Builder(fieldName).build(mapperBuilderContext);
sparseVectorFieldMapper.parse(context);
// new (added) handling: build an ad-hoc nested object mapper per field, with an
// "inference" sparse_vector child and an unindexed, unstored "text" child.
NestedObjectMapper.Builder nestedBuilder = new NestedObjectMapper.Builder(
fieldName,
context.indexSettings().getIndexVersionCreated()
);
SparseVectorFieldMapper.Builder sparseVectorFieldMapperBuilder = new SparseVectorFieldMapper.Builder(
"inference"
);
nestedBuilder.add(sparseVectorFieldMapperBuilder);
TextFieldMapper.Builder textFieldMapperBuilder = new TextFieldMapper.Builder("text", context.indexAnalyzers()).index(false)
.store(false);
nestedBuilder.add(textFieldMapperBuilder);
NestedObjectMapper nestedObjectMapper = nestedBuilder.build(mapperBuilderContext);

// Each field's inference results arrive as an array of result objects.
if (context.parser().nextToken() != XContentParser.Token.START_ARRAY) {
throw new IllegalArgumentException(
"[_semantic_text] produced inference must be an array of objects, expected a START_ARRAY but got: "
+ context.parser().currentToken()
);
}
// Middle loop: one nested document context per array element.
for (token = context.parser().nextToken(); token != XContentParser.Token.END_ARRAY; token = context.parser()
.nextToken()) {
DocumentParserContext nestedContext = context.createNestedContext(nestedObjectMapper);

if (token != XContentParser.Token.START_OBJECT) {
throw new IllegalArgumentException(
"each [_semantic_text] produced inference must be an object, expected a START_OBJECT but got: "
+ context.parser().currentToken()
);
}

// Inner loop: delegate each declared sub-field to its child mapper, tracking
// which ones were seen so completeness can be checked afterwards.
Set<String> visitedFields = new HashSet<>();
for (token = context.parser().nextToken(); token != XContentParser.Token.END_OBJECT; token = context.parser()
.nextToken()) {

if (token != XContentParser.Token.FIELD_NAME) {
throw new IllegalArgumentException(
"each [semantic_text] produced objects fields expect an object with field names, found " + token
);
}

String inferenceField = context.parser().currentName();
FieldMapper childNestedMapper = (FieldMapper) nestedObjectMapper.getMapper(inferenceField);
if (childNestedMapper == null) {
// Reject result fields that the nested mapper does not declare.
throw new IllegalArgumentException("unexpected inference result field name: " + inferenceField);
}
context.parser().nextToken();
childNestedMapper.parse(nestedContext);
visitedFields.add(inferenceField);
}
// Require every declared child field to appear in each result object.
// NOTE(review): the message says "unexpected", but a size mismatch here can only
// mean a MISSING field (extras were rejected above) — message looks misleading.
if (visitedFields.size() != nestedObjectMapper.getChildren().size()) {
throw new IllegalArgumentException("unexpected inference fields: " + visitedFields);
}
}

}
}

Expand Down

0 comments on commit b1d5bd0

Please sign in to comment.