From 148736230cee64485ec4b82b02757f3721d648be Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Wed, 1 Nov 2023 12:47:52 +0100 Subject: [PATCH] Generate a separate top-level field for adding inference, and create an additional field mapper to actually insert it into Lucene --- .../index/mapper/RootObjectMapper.java | 15 +- .../index/mapper/SemanticTextFieldMapper.java | 36 +---- .../SemanticTextInferenceFieldMapper.java | 142 ++++++++++++++++++ .../elasticsearch/indices/IndicesModule.java | 2 + ...FieldInferenceBulkRequestPreprocessor.java | 23 ++- 5 files changed, 171 insertions(+), 47 deletions(-) create mode 100644 server/src/main/java/org/elasticsearch/index/mapper/SemanticTextInferenceFieldMapper.java diff --git a/server/src/main/java/org/elasticsearch/index/mapper/RootObjectMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/RootObjectMapper.java index b5f968165548..b2051e979438 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/RootObjectMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/RootObjectMapper.java @@ -106,12 +106,25 @@ public RootObjectMapper.Builder addRuntimeFields(Map runti @Override public RootObjectMapper build(MapperBuilderContext context) { + // Check whether we should add field inference mapper + // TODO Find a better place for doing this + Map mappers = buildMappers(context); + boolean hasInference = mappers.values() + .stream() + .anyMatch(mapper -> mapper.typeName().equals(SemanticTextFieldMapper.CONTENT_TYPE)); + + if (hasInference) { + SemanticTextInferenceFieldMapper semanticTextInferenceFieldMapper = new SemanticTextInferenceFieldMapper.Builder().build( + context + ); + mappers.put(semanticTextInferenceFieldMapper.simpleName(), semanticTextInferenceFieldMapper); + } return new RootObjectMapper( name, enabled, subobjects, dynamic, - buildMappers(context), + mappers, runtimeFields, dynamicDateTimeFormatters, dynamicTemplates, diff --git 
a/server/src/main/java/org/elasticsearch/index/mapper/SemanticTextFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/SemanticTextFieldMapper.java index 4bc79628268f..bbde9e93a186 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/SemanticTextFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/SemanticTextFieldMapper.java @@ -146,42 +146,12 @@ public FieldMapper.Builder getMergeBuilder() { @Override public void parse(DocumentParserContext context) throws IOException { - if (context.parser().currentToken() != XContentParser.Token.START_OBJECT) { + if (context.parser().currentToken() != XContentParser.Token.VALUE_STRING) { throw new IllegalArgumentException( - "[semantic_text] fields must be a json object, expected a START_OBJECT but got: " + context.parser().currentToken() + "[semantic_text] fields must be a text value but got: " + context.parser().currentToken() ); } - - boolean textFound = false; - boolean inferenceFound = false; - for (XContentParser.Token token = context.parser().nextToken(); token != XContentParser.Token.END_OBJECT; token = context.parser() - .nextToken()) { - if (token != XContentParser.Token.FIELD_NAME) { - throw new IllegalArgumentException("[semantic_text] fields expect an object with field names, found " + token); - } - - String fieldName = context.parser().currentName(); - XContentParser.Token valueToken = context.parser().nextToken(); - switch (fieldName) { - case TEXT_SUBFIELD_NAME: - context.doc().add(new StringField(name() + TEXT_SUBFIELD_NAME, context.parser().textOrNull(), Field.Store.NO)); - textFound = true; - break; - case SPARSE_VECTOR_SUBFIELD_NAME: - sparseVectorFieldMapper.parse(context); - inferenceFound = true; - break; - default: - throw new IllegalArgumentException("Unexpected subfield value: " + fieldName); - } - } - - if (textFound == false) { - throw new IllegalArgumentException("[semantic_text] value does not contain [" + TEXT_SUBFIELD_NAME + "] subfield"); - } - 
if (inferenceFound == false) { - throw new IllegalArgumentException("[semantic_text] value does not contain [" + SPARSE_VECTOR_SUBFIELD_NAME + "] subfield"); - } + context.doc().add(new StringField(name(), context.parser().textOrNull(), Field.Store.NO)); } @Override diff --git a/server/src/main/java/org/elasticsearch/index/mapper/SemanticTextInferenceFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/SemanticTextInferenceFieldMapper.java new file mode 100644 index 000000000000..fc7eff34653f --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/mapper/SemanticTextInferenceFieldMapper.java @@ -0,0 +1,142 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.index.mapper; + +import org.apache.lucene.search.Query; +import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper; +import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper.SparseVectorFieldType; +import org.elasticsearch.index.query.SearchExecutionContext; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Map; + +public class SemanticTextInferenceFieldMapper extends FieldMapper { + + public static final String CONTENT_TYPE = "semantic_text_inference"; + + public static final String FIELD_NAME = "_semantic_text_inference"; + public static final String SPARSE_VECTOR_SUBFIELD_NAME = TaskType.SPARSE_EMBEDDING.toString(); + + private static SemanticTextInferenceFieldMapper toType(FieldMapper in) { + return (SemanticTextInferenceFieldMapper) in; + } + + public static class Builder extends FieldMapper.Builder { + + private final 
Parameter> meta = Parameter.metaParam(); + + public Builder() { + super(FIELD_NAME); + } + + @Override + protected Parameter[] getParameters() { + return new Parameter[] { meta }; + } + + private SemanticTextInferenceFieldType buildFieldType(MapperBuilderContext context) { + return new SemanticTextInferenceFieldType(context.buildFullName(name), meta.getValue()); + } + + @Override + public SemanticTextInferenceFieldMapper build(MapperBuilderContext context) { + return new SemanticTextInferenceFieldMapper(name(), new SemanticTextInferenceFieldType(name(), meta.getValue()), copyTo, this); + } + } + + public static final TypeParser PARSER = new TypeParser((n, c) -> new Builder(), notInMultiFields(CONTENT_TYPE)); + + public static class SemanticTextInferenceFieldType extends SimpleMappedFieldType { + + private SparseVectorFieldType sparseVectorFieldType; + + public SemanticTextInferenceFieldType(String name, Map meta) { + super(name, true, false, false, TextSearchInfo.NONE, meta); + } + + public SparseVectorFieldType getSparseVectorFieldType() { + return this.sparseVectorFieldType; + } + + @Override + public String typeName() { + return CONTENT_TYPE; + } + + @Override + public ValueFetcher valueFetcher(SearchExecutionContext context, String format) { + return SourceValueFetcher.identity(name(), context, format); + } + + @Override + public Query termQuery(Object value, SearchExecutionContext context) { + return sparseVectorFieldType.termQuery(value, context); + } + } + + private SemanticTextInferenceFieldMapper(String simpleName, MappedFieldType mappedFieldType, CopyTo copyTo, Builder builder) { + super(simpleName, mappedFieldType, MultiFields.empty(), copyTo); + } + + @Override + public FieldMapper.Builder getMergeBuilder() { + return new Builder().init(this); + } + + @Override + public void parse(DocumentParserContext context) throws IOException { + + if (context.parser().currentToken() != XContentParser.Token.START_OBJECT) { + throw new IllegalArgumentException( + 
"[_semantic_text_inference] fields must be a json object, expected a START_OBJECT but got: " + + context.parser().currentToken() + ); + } + + MapperBuilderContext mapperBuilderContext = new MapperBuilderContext(FIELD_NAME, false, false); + + // TODO Can we validate that semantic text fields have actual text values? + for (XContentParser.Token token = context.parser().nextToken(); token != XContentParser.Token.END_OBJECT; token = context.parser() + .nextToken()) { + if (token != XContentParser.Token.FIELD_NAME) { + throw new IllegalArgumentException("[semantic_text] fields expect an object with field names, found " + token); + } + + String fieldName = context.parser().currentName(); + + Mapper mapper = context.getMapper(fieldName); + if (SemanticTextFieldMapper.CONTENT_TYPE.equals(mapper.typeName()) == false) { + throw new IllegalArgumentException( + "Found [" + fieldName + "] in inference values, but it is not registered as a semantic_text field type" + ); + } + + context.parser().nextToken(); + SparseVectorFieldMapper sparseVectorFieldMapper = new SparseVectorFieldMapper.Builder(fieldName).build(mapperBuilderContext); + sparseVectorFieldMapper.parse(context); + } + } + + @Override + protected void parseCreateField(DocumentParserContext context) { + throw new AssertionError("parse is implemented directly"); + } + + @Override + protected String contentType() { + return CONTENT_TYPE; + } + + @Override + public SemanticTextInferenceFieldType fieldType() { + return (SemanticTextInferenceFieldType) super.fieldType(); + } +} diff --git a/server/src/main/java/org/elasticsearch/indices/IndicesModule.java b/server/src/main/java/org/elasticsearch/indices/IndicesModule.java index b94e3ca7785e..7688506a5764 100644 --- a/server/src/main/java/org/elasticsearch/indices/IndicesModule.java +++ b/server/src/main/java/org/elasticsearch/indices/IndicesModule.java @@ -56,6 +56,7 @@ import org.elasticsearch.index.mapper.RoutingFieldMapper; import 
org.elasticsearch.index.mapper.RuntimeField; import org.elasticsearch.index.mapper.SemanticTextFieldMapper; +import org.elasticsearch.index.mapper.SemanticTextInferenceFieldMapper; import org.elasticsearch.index.mapper.SeqNoFieldMapper; import org.elasticsearch.index.mapper.SourceFieldMapper; import org.elasticsearch.index.mapper.TextFieldMapper; @@ -199,6 +200,7 @@ public static Map getMappers(List mappe mappers.put(DenseVectorFieldMapper.CONTENT_TYPE, DenseVectorFieldMapper.PARSER); mappers.put(SparseVectorFieldMapper.CONTENT_TYPE, SparseVectorFieldMapper.PARSER); mappers.put(SemanticTextFieldMapper.CONTENT_TYPE, SemanticTextFieldMapper.PARSER); + mappers.put(SemanticTextInferenceFieldMapper.CONTENT_TYPE, SemanticTextInferenceFieldMapper.PARSER); for (MapperPlugin mapperPlugin : mapperPlugins) { for (Map.Entry entry : mapperPlugin.getMappers().entrySet()) { diff --git a/server/src/main/java/org/elasticsearch/ingest/FieldInferenceBulkRequestPreprocessor.java b/server/src/main/java/org/elasticsearch/ingest/FieldInferenceBulkRequestPreprocessor.java index 21a99365255d..758fcff4dade 100644 --- a/server/src/main/java/org/elasticsearch/ingest/FieldInferenceBulkRequestPreprocessor.java +++ b/server/src/main/java/org/elasticsearch/ingest/FieldInferenceBulkRequestPreprocessor.java @@ -20,12 +20,11 @@ import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.core.Releasable; import org.elasticsearch.index.Index; -import org.elasticsearch.index.mapper.SemanticTextFieldMapper; +import org.elasticsearch.index.mapper.SemanticTextInferenceFieldMapper; import org.elasticsearch.indices.IndicesService; import org.elasticsearch.inference.TaskType; import org.elasticsearch.plugins.internal.DocumentParsingObserver; -import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.function.BiConsumer; @@ -67,14 +66,15 @@ protected void processIndexRequest( assert indexRequest.isFieldInferenceDone() == false; IngestDocument ingestDocument 
= newIngestDocument(indexRequest); - List fieldNames = ingestDocument.getSource().entrySet() + List fieldNames = ingestDocument.getSource() + .entrySet() .stream() .filter(entry -> fieldNeedsInference(indexRequest, entry.getKey(), entry.getValue())) .map(Map.Entry::getKey) .toList(); // Runs inference sequentially. This makes easier sync and removes the problem of having multiple - // BulkItemResponses for a single bulk request in TransportBulkAction.unwrappingSingleItemBulkResponse + // BulkItemResponses for a single bulk request in TransportBulkAction.unwrappingSingleItemBulkResponse runInferenceForFields(indexRequest, fieldNames, refs.acquire(), slot, ingestDocument, onFailure); } @@ -166,14 +166,11 @@ private void runInferenceForFields( client.execute(InferenceAction.INSTANCE, inferenceRequest, new ActionListener<>() { @Override public void onResponse(InferenceAction.Response response) { - // Transform into two subfields, one with the actual text and other with the inference - Map newFieldValue = new HashMap<>(); - newFieldValue.put(SemanticTextFieldMapper.TEXT_SUBFIELD_NAME, fieldValue); - newFieldValue.put( - SemanticTextFieldMapper.SPARSE_VECTOR_SUBFIELD_NAME, - response.getResult().asMap(fieldName).get(fieldName) + // Transform into another top-level subfield + ingestDocument.setFieldValue( + SemanticTextInferenceFieldMapper.FIELD_NAME + "." 
+ fieldName, + response.getResult().asMap(fieldName).get(fieldName) ); - ingestDocument.setFieldValue(fieldName, newFieldValue); // Run inference for next fields runInferenceForFields(indexRequest, nextFieldNames, ref, position, ingestDocument, onFailure); @@ -183,8 +180,8 @@ public void onResponse(InferenceAction.Response response) { public void onFailure(Exception e) { // Wrap exception in an illegal argument exception, as there is a problem with the model or model config onFailure.accept( - position, - new IllegalArgumentException("Error performing inference for field [" + fieldName + "]: " + e.getMessage(), e) + position, + new IllegalArgumentException("Error performing inference for field [" + fieldName + "]: " + e.getMessage(), e) ); ref.close(); }