Skip to content

Commit

Permalink
Generate a separate top-level field for adding inference, and create…
Browse files Browse the repository at this point in the history
… an additional field mapper to actually insert it into Lucene
  • Loading branch information
carlosdelest committed Nov 1, 2023
1 parent 338ecd7 commit 1487362
Show file tree
Hide file tree
Showing 5 changed files with 171 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -106,12 +106,25 @@ public RootObjectMapper.Builder addRuntimeFields(Map<String, RuntimeField> runti

@Override
public RootObjectMapper build(MapperBuilderContext context) {
// Check whether we should add field inference mapper
// TODO Find a better place for doing this
Map<String, Mapper> mappers = buildMappers(context);
boolean hasInference = mappers.values()
.stream()
.anyMatch(mapper -> mapper.typeName().equals(SemanticTextFieldMapper.CONTENT_TYPE));

if (hasInference) {
SemanticTextInferenceFieldMapper semanticTextInferenceFieldMapper = new SemanticTextInferenceFieldMapper.Builder().build(
context
);
mappers.put(semanticTextInferenceFieldMapper.simpleName(), semanticTextInferenceFieldMapper);
}
return new RootObjectMapper(
name,
enabled,
subobjects,
dynamic,
buildMappers(context),
mappers,
runtimeFields,
dynamicDateTimeFormatters,
dynamicTemplates,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,42 +146,12 @@ public FieldMapper.Builder getMergeBuilder() {
@Override
public void parse(DocumentParserContext context) throws IOException {

if (context.parser().currentToken() != XContentParser.Token.START_OBJECT) {
if (context.parser().currentToken() != XContentParser.Token.VALUE_STRING) {
throw new IllegalArgumentException(
"[semantic_text] fields must be a json object, expected a START_OBJECT but got: " + context.parser().currentToken()
"[semantic_text] fields must be a text value but got: " + context.parser().currentToken()
);
}

boolean textFound = false;
boolean inferenceFound = false;
for (XContentParser.Token token = context.parser().nextToken(); token != XContentParser.Token.END_OBJECT; token = context.parser()
.nextToken()) {
if (token != XContentParser.Token.FIELD_NAME) {
throw new IllegalArgumentException("[semantic_text] fields expect an object with field names, found " + token);
}

String fieldName = context.parser().currentName();
XContentParser.Token valueToken = context.parser().nextToken();
switch (fieldName) {
case TEXT_SUBFIELD_NAME:
context.doc().add(new StringField(name() + TEXT_SUBFIELD_NAME, context.parser().textOrNull(), Field.Store.NO));
textFound = true;
break;
case SPARSE_VECTOR_SUBFIELD_NAME:
sparseVectorFieldMapper.parse(context);
inferenceFound = true;
break;
default:
throw new IllegalArgumentException("Unexpected subfield value: " + fieldName);
}
}

if (textFound == false) {
throw new IllegalArgumentException("[semantic_text] value does not contain [" + TEXT_SUBFIELD_NAME + "] subfield");
}
if (inferenceFound == false) {
throw new IllegalArgumentException("[semantic_text] value does not contain [" + SPARSE_VECTOR_SUBFIELD_NAME + "] subfield");
}
context.doc().add(new StringField(name(), context.parser().textOrNull(), Field.Store.NO));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/

package org.elasticsearch.index.mapper;

import org.apache.lucene.search.Query;
import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper;
import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper.SparseVectorFieldType;
import org.elasticsearch.index.query.SearchExecutionContext;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xcontent.XContentParser;

import java.io.IOException;
import java.util.Map;

/**
 * A {@link FieldMapper} for the reserved {@code _semantic_text_inference} top-level field.
 * <p>
 * This field holds the inference results (sparse embeddings) computed for {@code semantic_text}
 * fields in the same document. At parse time each key of the object is expected to name a
 * registered {@code semantic_text} field; its value is indexed through a dynamically built
 * {@link SparseVectorFieldMapper}.
 */
public class SemanticTextInferenceFieldMapper extends FieldMapper {

    public static final String CONTENT_TYPE = "semantic_text_inference";

    /** Reserved top-level field name under which inference results are stored. */
    public static final String FIELD_NAME = "_semantic_text_inference";
    public static final String SPARSE_VECTOR_SUBFIELD_NAME = TaskType.SPARSE_EMBEDDING.toString();

    // Conventional FieldMapper down-cast helper, used by Parameter initializers in merge flows.
    private static SemanticTextInferenceFieldMapper toType(FieldMapper in) {
        return (SemanticTextInferenceFieldMapper) in;
    }

    public static class Builder extends FieldMapper.Builder {

        private final Parameter<Map<String, String>> meta = Parameter.metaParam();

        public Builder() {
            super(FIELD_NAME);
        }

        @Override
        protected Parameter<?>[] getParameters() {
            return new Parameter<?>[] { meta };
        }

        // Builds the field type with the context-resolved full name, so nested contexts
        // produce a correctly qualified field name.
        private SemanticTextInferenceFieldType buildFieldType(MapperBuilderContext context) {
            return new SemanticTextInferenceFieldType(context.buildFullName(name), meta.getValue());
        }

        @Override
        public SemanticTextInferenceFieldMapper build(MapperBuilderContext context) {
            // Fix: previously inlined a field type using the bare simple name, leaving
            // buildFieldType(context) unused and bypassing full-name resolution.
            return new SemanticTextInferenceFieldMapper(name(), buildFieldType(context), copyTo, this);
        }
    }

    public static final TypeParser PARSER = new TypeParser((n, c) -> new Builder(), notInMultiFields(CONTENT_TYPE));

    public static class SemanticTextInferenceFieldType extends SimpleMappedFieldType {

        // NOTE(review): never assigned anywhere in this class — presumably intended to be
        // populated once inference results are parsed; confirm the intended wiring.
        private SparseVectorFieldType sparseVectorFieldType;

        public SemanticTextInferenceFieldType(String name, Map<String, String> meta) {
            super(name, true, false, false, TextSearchInfo.NONE, meta);
        }

        public SparseVectorFieldType getSparseVectorFieldType() {
            return this.sparseVectorFieldType;
        }

        @Override
        public String typeName() {
            return CONTENT_TYPE;
        }

        @Override
        public ValueFetcher valueFetcher(SearchExecutionContext context, String format) {
            return SourceValueFetcher.identity(name(), context, format);
        }

        @Override
        public Query termQuery(Object value, SearchExecutionContext context) {
            // Fix: sparseVectorFieldType is never assigned, so the unguarded delegation
            // was a guaranteed NullPointerException. Fail with an explicit error instead.
            if (sparseVectorFieldType == null) {
                throw new IllegalStateException(
                    "[" + name() + "] does not support term queries: no sparse vector field type is available"
                );
            }
            return sparseVectorFieldType.termQuery(value, context);
        }
    }

    private SemanticTextInferenceFieldMapper(String simpleName, MappedFieldType mappedFieldType, CopyTo copyTo, Builder builder) {
        super(simpleName, mappedFieldType, MultiFields.empty(), copyTo);
    }

    @Override
    public FieldMapper.Builder getMergeBuilder() {
        return new Builder().init(this);
    }

    /**
     * Parses the {@code _semantic_text_inference} object. Each key must name a registered
     * {@code semantic_text} field; the corresponding value is indexed as a sparse vector.
     *
     * @throws IllegalArgumentException if the value is not an object, contains non-field-name
     *         tokens, or references a field that is not a {@code semantic_text} field
     */
    @Override
    public void parse(DocumentParserContext context) throws IOException {

        if (context.parser().currentToken() != XContentParser.Token.START_OBJECT) {
            throw new IllegalArgumentException(
                "[" + FIELD_NAME + "] fields must be a json object, expected a START_OBJECT but got: "
                    + context.parser().currentToken()
            );
        }

        MapperBuilderContext mapperBuilderContext = new MapperBuilderContext(FIELD_NAME, false, false);

        // TODO Can we validate that semantic text fields have actual text values?
        for (XContentParser.Token token = context.parser().nextToken(); token != XContentParser.Token.END_OBJECT; token = context.parser()
            .nextToken()) {
            if (token != XContentParser.Token.FIELD_NAME) {
                // Fix: message previously reported "[semantic_text]" for this mapper's errors.
                throw new IllegalArgumentException("[" + FIELD_NAME + "] fields expect an object with field names, found " + token);
            }

            String fieldName = context.parser().currentName();

            Mapper mapper = context.getMapper(fieldName);
            // Fix: guard against a null mapper — getMapper returns null for unknown fields,
            // which previously caused an NPE instead of the intended error message.
            if (mapper == null || SemanticTextFieldMapper.CONTENT_TYPE.equals(mapper.typeName()) == false) {
                throw new IllegalArgumentException(
                    "Found [" + fieldName + "] in inference values, but it is not registered as a semantic_text field type"
                );
            }

            // Advance past the field name to the value, then index it as a sparse vector
            // via a transient mapper rooted under FIELD_NAME.
            context.parser().nextToken();
            SparseVectorFieldMapper sparseVectorFieldMapper = new SparseVectorFieldMapper.Builder(fieldName).build(mapperBuilderContext);
            sparseVectorFieldMapper.parse(context);
        }
    }

    @Override
    protected void parseCreateField(DocumentParserContext context) {
        throw new AssertionError("parse is implemented directly");
    }

    @Override
    protected String contentType() {
        return CONTENT_TYPE;
    }

    @Override
    public SemanticTextInferenceFieldType fieldType() {
        return (SemanticTextInferenceFieldType) super.fieldType();
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
import org.elasticsearch.index.mapper.RoutingFieldMapper;
import org.elasticsearch.index.mapper.RuntimeField;
import org.elasticsearch.index.mapper.SemanticTextFieldMapper;
import org.elasticsearch.index.mapper.SemanticTextInferenceFieldMapper;
import org.elasticsearch.index.mapper.SeqNoFieldMapper;
import org.elasticsearch.index.mapper.SourceFieldMapper;
import org.elasticsearch.index.mapper.TextFieldMapper;
Expand Down Expand Up @@ -199,6 +200,7 @@ public static Map<String, Mapper.TypeParser> getMappers(List<MapperPlugin> mappe
mappers.put(DenseVectorFieldMapper.CONTENT_TYPE, DenseVectorFieldMapper.PARSER);
mappers.put(SparseVectorFieldMapper.CONTENT_TYPE, SparseVectorFieldMapper.PARSER);
mappers.put(SemanticTextFieldMapper.CONTENT_TYPE, SemanticTextFieldMapper.PARSER);
mappers.put(SemanticTextInferenceFieldMapper.CONTENT_TYPE, SemanticTextInferenceFieldMapper.PARSER);

for (MapperPlugin mapperPlugin : mapperPlugins) {
for (Map.Entry<String, Mapper.TypeParser> entry : mapperPlugin.getMappers().entrySet()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,11 @@
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.index.Index;
import org.elasticsearch.index.mapper.SemanticTextFieldMapper;
import org.elasticsearch.index.mapper.SemanticTextInferenceFieldMapper;
import org.elasticsearch.indices.IndicesService;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.plugins.internal.DocumentParsingObserver;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.BiConsumer;
Expand Down Expand Up @@ -67,14 +66,15 @@ protected void processIndexRequest(
assert indexRequest.isFieldInferenceDone() == false;

IngestDocument ingestDocument = newIngestDocument(indexRequest);
List<String> fieldNames = ingestDocument.getSource().entrySet()
List<String> fieldNames = ingestDocument.getSource()
.entrySet()
.stream()
.filter(entry -> fieldNeedsInference(indexRequest, entry.getKey(), entry.getValue()))
.map(Map.Entry::getKey)
.toList();

// Runs inference sequentially. This makes easier sync and removes the problem of having multiple
// BulkItemResponses for a single bulk request in TransportBulkAction.unwrappingSingleItemBulkResponse
// BulkItemResponses for a single bulk request in TransportBulkAction.unwrappingSingleItemBulkResponse
runInferenceForFields(indexRequest, fieldNames, refs.acquire(), slot, ingestDocument, onFailure);

}
Expand Down Expand Up @@ -166,14 +166,11 @@ private void runInferenceForFields(
client.execute(InferenceAction.INSTANCE, inferenceRequest, new ActionListener<>() {
@Override
public void onResponse(InferenceAction.Response response) {
// Transform into two subfields, one with the actual text and other with the inference
Map<String, Object> newFieldValue = new HashMap<>();
newFieldValue.put(SemanticTextFieldMapper.TEXT_SUBFIELD_NAME, fieldValue);
newFieldValue.put(
SemanticTextFieldMapper.SPARSE_VECTOR_SUBFIELD_NAME,
response.getResult().asMap(fieldName).get(fieldName)
// Transform into another top-level subfield
ingestDocument.setFieldValue(
SemanticTextInferenceFieldMapper.FIELD_NAME + "." + fieldName,
response.getResult().asMap(fieldName).get(fieldName)
);
ingestDocument.setFieldValue(fieldName, newFieldValue);

// Run inference for next fields
runInferenceForFields(indexRequest, nextFieldNames, ref, position, ingestDocument, onFailure);
Expand All @@ -183,8 +180,8 @@ public void onResponse(InferenceAction.Response response) {
public void onFailure(Exception e) {
// Wrap exception in an illegal argument exception, as there is a problem with the model or model config
onFailure.accept(
position,
new IllegalArgumentException("Error performing inference for field [" + fieldName + "]: " + e.getMessage(), e)
position,
new IllegalArgumentException("Error performing inference for field [" + fieldName + "]: " + e.getMessage(), e)
);
ref.close();
}
Expand Down

0 comments on commit 1487362

Please sign in to comment.