From 979e34cba772cddd35b988e36f6cf33c44dc351c Mon Sep 17 00:00:00 2001 From: Jim Ferenczi Date: Thu, 21 Nov 2024 14:40:51 +0000 Subject: [PATCH] Add an inference metadata field instead of storing the inference in the original field --- .../elasticsearch/index/IndexVersions.java | 1 + .../index/engine/TranslogDirectoryReader.java | 2 +- .../index/mapper/DocumentParserContext.java | 13 ++- .../mapper/InferenceMetadataFieldsMapper.java | 90 +++++++++++++++++++ .../index/mapper/SourceFieldMapper.java | 23 +++-- .../elasticsearch/indices/IndicesModule.java | 2 + .../xpack/inference/InferencePlugin.java | 2 +- .../ShardBulkInferenceActionFilter.java | 38 ++++++-- .../mapper/SemanticTextFieldMapper.java | 12 ++- .../ShardBulkInferenceActionFilterTests.java | 8 +- 10 files changed, 174 insertions(+), 17 deletions(-) create mode 100644 server/src/main/java/org/elasticsearch/index/mapper/InferenceMetadataFieldsMapper.java diff --git a/server/src/main/java/org/elasticsearch/index/IndexVersions.java b/server/src/main/java/org/elasticsearch/index/IndexVersions.java index 7a5f469a57fa1..df6c30770e505 100644 --- a/server/src/main/java/org/elasticsearch/index/IndexVersions.java +++ b/server/src/main/java/org/elasticsearch/index/IndexVersions.java @@ -135,6 +135,7 @@ private static Version parseUnchecked(String version) { public static final IndexVersion LOGSDB_DEFAULT_IGNORE_DYNAMIC_BEYOND_LIMIT = def(9_001_00_0, Version.LUCENE_10_0_0); public static final IndexVersion TIME_BASED_K_ORDERED_DOC_ID = def(9_002_00_0, Version.LUCENE_10_0_0); public static final IndexVersion DEPRECATE_SOURCE_MODE_MAPPER = def(9_003_00_0, Version.LUCENE_10_0_0); + public static final IndexVersion INFERENCE_METADATA_FIELDS = def(9_004_00_0, Version.LUCENE_10_0_0); /* * STOP! READ THIS FIRST!
No, really, * ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _ diff --git a/server/src/main/java/org/elasticsearch/index/engine/TranslogDirectoryReader.java b/server/src/main/java/org/elasticsearch/index/engine/TranslogDirectoryReader.java index 0f772b49bf92b..73f49021805bc 100644 --- a/server/src/main/java/org/elasticsearch/index/engine/TranslogDirectoryReader.java +++ b/server/src/main/java/org/elasticsearch/index/engine/TranslogDirectoryReader.java @@ -440,7 +440,7 @@ private void readStoredFieldsDirectly(StoredFieldVisitor visitor) throws IOExcep SourceFieldMapper mapper = mappingLookup.getMapping().getMetadataMapperByClass(SourceFieldMapper.class); if (mapper != null) { try { - sourceBytes = mapper.applyFilters(sourceBytes, null); + sourceBytes = mapper.applyFilters(null, sourceBytes, null); } catch (IOException e) { throw new IOException("Failed to reapply filters after reading from translog", e); } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/DocumentParserContext.java b/server/src/main/java/org/elasticsearch/index/mapper/DocumentParserContext.java index 51e4e9f4c1b5e..83ac81c768269 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/DocumentParserContext.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/DocumentParserContext.java @@ -42,10 +42,10 @@ public abstract class DocumentParserContext { /** * Wraps a given context while allowing to override some of its behaviour by re-implementing some of the non final methods */ - private static class Wrapper extends DocumentParserContext { + static class Wrapper extends DocumentParserContext { private final DocumentParserContext in; - private Wrapper(ObjectMapper parent, DocumentParserContext in) { + Wrapper(ObjectMapper parent, DocumentParserContext in) { super(parent, parent.dynamic == null ? 
in.dynamic : parent.dynamic, in); this.in = in; } @@ -60,6 +60,11 @@ public boolean isWithinCopyTo() { return in.isWithinCopyTo(); } + @Override + public boolean isWithinInferenceMetadata() { + return in.isWithinInferenceMetadata(); + } + @Override public ContentPath path() { return in.path(); @@ -648,6 +653,10 @@ public boolean isWithinCopyTo() { return false; } + public boolean isWithinInferenceMetadata() { + return false; + } + boolean inArrayScope() { return currentScope == Scope.ARRAY; } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/InferenceMetadataFieldsMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/InferenceMetadataFieldsMapper.java new file mode 100644 index 0000000000000..76638c362e549 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/mapper/InferenceMetadataFieldsMapper.java @@ -0,0 +1,90 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
+ */ + +package org.elasticsearch.index.mapper; + +import org.apache.lucene.search.Query; +import org.elasticsearch.common.xcontent.XContentParserUtils; +import org.elasticsearch.index.query.QueryShardException; +import org.elasticsearch.index.query.SearchExecutionContext; +import org.elasticsearch.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Map; + +public class InferenceMetadataFieldsMapper extends MetadataFieldMapper { + public static final String NAME = "_inference_fields"; + public static final String CONTENT_TYPE = "_inference_fields"; + + private static final InferenceMetadataFieldsMapper INSTANCE = new InferenceMetadataFieldsMapper(); + + public static final TypeParser PARSER = new FixedTypeParser(c -> INSTANCE); + + public static final class InferenceFieldType extends MappedFieldType { + private static InferenceFieldType INSTANCE = new InferenceFieldType(); + + public InferenceFieldType() { + super(NAME, false, false, false, TextSearchInfo.NONE, Map.of()); + } + + @Override + public ValueFetcher valueFetcher(SearchExecutionContext context, String format) { + // TODO: return the map from the individual semantic text fields? 
+ return null; + } + + @Override + public String typeName() { + return CONTENT_TYPE; + } + + @Override + public Query termQuery(Object value, SearchExecutionContext context) { + throw new QueryShardException( + context, + "[" + name() + "] field which is of type [" + typeName() + "], does not support term queries" + ); + } + } + + private InferenceMetadataFieldsMapper() { + super(InferenceFieldType.INSTANCE); + } + + @Override + protected String contentType() { + return CONTENT_TYPE; + } + + @Override + protected boolean supportsParsingObject() { + return true; + } + + @Override + protected void parseCreateField(DocumentParserContext context) throws IOException { + XContentParser parser = context.parser(); + XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + XContentParserUtils.ensureExpectedToken(XContentParser.Token.FIELD_NAME, parser.currentToken(), parser); + String fieldName = parser.currentName(); + Mapper mapper = context.mappingLookup().getMapper(fieldName); + if (mapper != null && mapper instanceof InferenceFieldMapper && mapper instanceof FieldMapper fieldMapper) { + fieldMapper.parseCreateField(new DocumentParserContext.Wrapper(context.parent(), context) { + @Override + public boolean isWithinInferenceMetadata() { + return true; + } + }); + } else { + throw new IllegalArgumentException("Illegal inference field [" + fieldName + "] found."); + } + } + } +} diff --git a/server/src/main/java/org/elasticsearch/index/mapper/SourceFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/SourceFieldMapper.java index e5b12f748543f..08565e8b6ae3a 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/SourceFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/SourceFieldMapper.java @@ -413,7 +413,7 @@ public boolean isComplete() { public void preParse(DocumentParserContext context) throws 
IOException { BytesReference originalSource = context.sourceToParse().source(); XContentType contentType = context.sourceToParse().getXContentType(); - final BytesReference adaptedSource = applyFilters(originalSource, contentType); + final BytesReference adaptedSource = applyFilters(context.mappingLookup(), originalSource, contentType); if (adaptedSource != null) { final BytesRef ref = adaptedSource.toBytesRef(); @@ -430,13 +430,26 @@ public void preParse(DocumentParserContext context) throws IOException { } @Nullable - public BytesReference applyFilters(@Nullable BytesReference originalSource, @Nullable XContentType contentType) throws IOException { - if (stored() == false) { + public BytesReference applyFilters( + @Nullable MappingLookup mappingLookup, + @Nullable BytesReference originalSource, + @Nullable XContentType contentType + ) throws IOException { + if (stored() == false || originalSource == null) { return null; } - if (originalSource != null && sourceFilter != null) { + var modSourceFilter = sourceFilter; + if (mappingLookup != null && mappingLookup.inferenceFields().isEmpty() == false) { + String[] modExcludes = new String[excludes != null ? 
excludes.length + 1 : 1]; + if (excludes != null) { + System.arraycopy(excludes, 0, modExcludes, 0, excludes.length); + } + modExcludes[modExcludes.length - 1] = InferenceMetadataFieldsMapper.NAME; + modSourceFilter = new SourceFilter(includes, modExcludes); + } + if (modSourceFilter != null) { // Percolate and tv APIs may not set the source and that is ok, because these APIs will not index any data - return Source.fromBytes(originalSource, contentType).filter(sourceFilter).internalSourceRef(); + return Source.fromBytes(originalSource, contentType).filter(modSourceFilter).internalSourceRef(); } else { return originalSource; } diff --git a/server/src/main/java/org/elasticsearch/indices/IndicesModule.java b/server/src/main/java/org/elasticsearch/indices/IndicesModule.java index 340bff4e1c852..1ca4aee887c1b 100644 --- a/server/src/main/java/org/elasticsearch/indices/IndicesModule.java +++ b/server/src/main/java/org/elasticsearch/indices/IndicesModule.java @@ -42,6 +42,7 @@ import org.elasticsearch.index.mapper.IgnoredSourceFieldMapper; import org.elasticsearch.index.mapper.IndexFieldMapper; import org.elasticsearch.index.mapper.IndexModeFieldMapper; +import org.elasticsearch.index.mapper.InferenceMetadataFieldsMapper; import org.elasticsearch.index.mapper.IpFieldMapper; import org.elasticsearch.index.mapper.IpScriptFieldType; import org.elasticsearch.index.mapper.KeywordFieldMapper; @@ -272,6 +273,7 @@ private static Map initBuiltInMetadataMa builtInMetadataMappers.put(SeqNoFieldMapper.NAME, SeqNoFieldMapper.PARSER); builtInMetadataMappers.put(DocCountFieldMapper.NAME, DocCountFieldMapper.PARSER); builtInMetadataMappers.put(DataStreamTimestampFieldMapper.NAME, DataStreamTimestampFieldMapper.PARSER); + builtInMetadataMappers.put(InferenceMetadataFieldsMapper.NAME, InferenceMetadataFieldsMapper.PARSER); // _field_names must be added last so that it has a chance to see all the other mappers builtInMetadataMappers.put(FieldNamesFieldMapper.NAME, 
FieldNamesFieldMapper.PARSER); return Collections.unmodifiableMap(builtInMetadataMappers); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 62405a2e9f7de..6495d23e905c8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -232,7 +232,7 @@ public Collection createComponents(PluginServices services) { } inferenceServiceRegistry.set(registry); - var actionFilter = new ShardBulkInferenceActionFilter(registry, modelRegistry); + var actionFilter = new ShardBulkInferenceActionFilter(services.clusterService(), registry, modelRegistry); shardBulkInferenceActionFilter.set(actionFilter); var meterRegistry = services.telemetryProvider().getMeterRegistry(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java index dd59230e575c4..8c3a58bc745ec 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -24,11 +24,15 @@ import org.elasticsearch.action.support.RefCountingRunnable; import org.elasticsearch.action.update.UpdateRequest; import org.elasticsearch.cluster.metadata.InferenceFieldMetadata; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.util.concurrent.AtomicArray; import org.elasticsearch.common.xcontent.support.XContentMapValues; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Releasable; import 
org.elasticsearch.core.TimeValue; +import org.elasticsearch.index.IndexVersion; +import org.elasticsearch.index.IndexVersions; +import org.elasticsearch.index.mapper.InferenceMetadataFieldsMapper; import org.elasticsearch.inference.ChunkedInferenceServiceResults; import org.elasticsearch.inference.ChunkingOptions; import org.elasticsearch.inference.InferenceService; @@ -68,15 +72,26 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter { protected static final int DEFAULT_BATCH_SIZE = 512; + private final ClusterService clusterService; private final InferenceServiceRegistry inferenceServiceRegistry; private final ModelRegistry modelRegistry; private final int batchSize; - public ShardBulkInferenceActionFilter(InferenceServiceRegistry inferenceServiceRegistry, ModelRegistry modelRegistry) { - this(inferenceServiceRegistry, modelRegistry, DEFAULT_BATCH_SIZE); + public ShardBulkInferenceActionFilter( + ClusterService clusterService, + InferenceServiceRegistry inferenceServiceRegistry, + ModelRegistry modelRegistry + ) { + this(clusterService, inferenceServiceRegistry, modelRegistry, DEFAULT_BATCH_SIZE); } - public ShardBulkInferenceActionFilter(InferenceServiceRegistry inferenceServiceRegistry, ModelRegistry modelRegistry, int batchSize) { + public ShardBulkInferenceActionFilter( + ClusterService clusterService, + InferenceServiceRegistry inferenceServiceRegistry, + ModelRegistry modelRegistry, + int batchSize + ) { + this.clusterService = clusterService; this.inferenceServiceRegistry = inferenceServiceRegistry; this.modelRegistry = modelRegistry; this.batchSize = batchSize; @@ -112,7 +127,8 @@ private void processBulkShardRequest( BulkShardRequest bulkShardRequest, Runnable onCompletion ) { - new AsyncBulkShardInferenceAction(fieldInferenceMap, bulkShardRequest, onCompletion).run(); + var index = clusterService.state().getMetadata().index(bulkShardRequest.index()); + new AsyncBulkShardInferenceAction(index.getCreationVersion(), 
fieldInferenceMap, bulkShardRequest, onCompletion).run(); } private record InferenceProvider(InferenceService service, Model model) {} @@ -165,16 +181,19 @@ void addFailure(Exception exc) { } private class AsyncBulkShardInferenceAction implements Runnable { + private final IndexVersion indexCreatedVersion; private final Map fieldInferenceMap; private final BulkShardRequest bulkShardRequest; private final Runnable onCompletion; private final AtomicArray inferenceResults; private AsyncBulkShardInferenceAction( + IndexVersion indexCreatedVersion, Map fieldInferenceMap, BulkShardRequest bulkShardRequest, Runnable onCompletion ) { + this.indexCreatedVersion = indexCreatedVersion; this.fieldInferenceMap = fieldInferenceMap; this.bulkShardRequest = bulkShardRequest; this.inferenceResults = new AtomicArray<>(bulkShardRequest.items().length); @@ -379,6 +398,8 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons final IndexRequest indexRequest = getIndexRequestOrNull(item.request()); var newDocMap = indexRequest.sourceAsMap(); + Map inferenceFieldsMap = new HashMap<>(); + final boolean addMetadataField = indexCreatedVersion.onOrAfter(IndexVersions.INFERENCE_METADATA_FIELDS); for (var entry : response.responses.entrySet()) { var fieldName = entry.getKey(); var responses = entry.getValue(); @@ -397,7 +418,14 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons ), indexRequest.getContentType() ); - SemanticTextFieldMapper.insertValue(fieldName, newDocMap, result); + if (addMetadataField) { + inferenceFieldsMap.put(fieldName, result); + } else { + SemanticTextFieldMapper.insertValue(fieldName, newDocMap, result); + } + } + if (addMetadataField) { + newDocMap.put(InferenceMetadataFieldsMapper.NAME, inferenceFieldsMap); } indexRequest.source(newDocMap, indexRequest.getContentType()); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java index 89a54ffe29177..da219d001cd0b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java @@ -21,6 +21,7 @@ import org.elasticsearch.features.NodeFeature; import org.elasticsearch.index.IndexSettings; import org.elasticsearch.index.IndexVersion; +import org.elasticsearch.index.IndexVersions; import org.elasticsearch.index.fielddata.FieldDataContext; import org.elasticsearch.index.fielddata.IndexFieldData; import org.elasticsearch.index.mapper.BlockLoader; @@ -286,6 +287,11 @@ public FieldMapper.Builder getMergeBuilder() { @Override protected void parseCreateField(DocumentParserContext context) throws IOException { + if (context.isWithinInferenceMetadata() == false) { + assert indexSettings.getIndexVersionCreated().onOrAfter(IndexVersions.INFERENCE_METADATA_FIELDS); + // ignore original text value + return; + } XContentParser parser = context.parser(); if (parser.currentToken() == XContentParser.Token.VALUE_NULL) { return; @@ -495,8 +501,10 @@ public Query existsQuery(SearchExecutionContext context) { @Override public ValueFetcher valueFetcher(SearchExecutionContext context, String format) { - // Redirect the fetcher to load the original values of the field - return SourceValueFetcher.toString(getOriginalTextFieldName(name()), context, format); + String fieldName = context.getIndexSettings().getIndexVersionCreated().onOrAfter(IndexVersions.INFERENCE_METADATA_FIELDS) + ? 
name() + : getOriginalTextFieldName(name()); + return SourceValueFetcher.toString(fieldName, context, format); } @Override diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java index 770e6e3cb9cf4..ca1aba11187d5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java @@ -320,7 +320,13 @@ private static ShardBulkInferenceActionFilter createFilter(ThreadPool threadPool InferenceServiceRegistry inferenceServiceRegistry = mock(InferenceServiceRegistry.class); when(inferenceServiceRegistry.getService(any())).thenReturn(Optional.of(inferenceService)); - ShardBulkInferenceActionFilter filter = new ShardBulkInferenceActionFilter(inferenceServiceRegistry, modelRegistry, batchSize); + // TODO: add cluster service + ShardBulkInferenceActionFilter filter = new ShardBulkInferenceActionFilter( + null, + inferenceServiceRegistry, + modelRegistry, + batchSize + ); return filter; }