Add an inference metadata field instead of storing the inference in the original field
jimczi committed Nov 21, 2024
1 parent 5500a5e commit 979e34c
Showing 10 changed files with 174 additions and 17 deletions.
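For orientation, the sketch below contrasts the two document-source layouts this commit switches between, using a hypothetical semantic_text field named `body` with the per-field inference payload abbreviated (its exact shape is owned by the semantic_text mapper): on indices created on or after the new INFERENCE_METADATA_FIELDS version, inference results are written under the `_inference_fields` metadata field and the original field keeps only its raw text, instead of the results being embedded in the original field as before.

// Hypothetical, abbreviated document sources; illustration only.
String legacyLayout = """
    { "body": { "text": "a brown fox", "inference": { ... } } }
    """;  // before: inference results embedded in the original field
String metadataLayout = """
    {
      "body": "a brown fox",
      "_inference_fields": { "body": { ... } }
    }
    """;  // after: results under the _inference_fields metadata field; raw text stays on the field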
@@ -135,6 +135,7 @@ private static Version parseUnchecked(String version) {
public static final IndexVersion LOGSDB_DEFAULT_IGNORE_DYNAMIC_BEYOND_LIMIT = def(9_001_00_0, Version.LUCENE_10_0_0);
public static final IndexVersion TIME_BASED_K_ORDERED_DOC_ID = def(9_002_00_0, Version.LUCENE_10_0_0);
public static final IndexVersion DEPRECATE_SOURCE_MODE_MAPPER = def(9_003_00_0, Version.LUCENE_10_0_0);
public static final IndexVersion INFERENCE_METADATA_FIELDS = def(9_004_00_0, Version.LUCENE_10_0_0);
/*
* STOP! READ THIS FIRST! No, really,
* ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _
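Later hunks in this commit consult the new constant to decide, per index, whether the metadata-field layout applies; a minimal sketch of that check (variable name illustrative):

// Illustrative: the gating check used by the filter and mappers changed below.
boolean useInferenceMetadataFields = indexCreatedVersion.onOrAfter(IndexVersions.INFERENCE_METADATA_FIELDS);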
@@ -440,7 +440,7 @@ private void readStoredFieldsDirectly(StoredFieldVisitor visitor) throws IOExcep
SourceFieldMapper mapper = mappingLookup.getMapping().getMetadataMapperByClass(SourceFieldMapper.class);
if (mapper != null) {
try {
sourceBytes = mapper.applyFilters(sourceBytes, null);
sourceBytes = mapper.applyFilters(null, sourceBytes, null);
} catch (IOException e) {
throw new IOException("Failed to reapply filters after reading from translog", e);
}
@@ -42,10 +42,10 @@ public abstract class DocumentParserContext {
/**
* Wraps a given context while allowing to override some of its behaviour by re-implementing some of the non final methods
*/
private static class Wrapper extends DocumentParserContext {
static class Wrapper extends DocumentParserContext {
private final DocumentParserContext in;

private Wrapper(ObjectMapper parent, DocumentParserContext in) {
Wrapper(ObjectMapper parent, DocumentParserContext in) {
super(parent, parent.dynamic == null ? in.dynamic : parent.dynamic, in);
this.in = in;
}
@@ -60,6 +60,11 @@ public boolean isWithinCopyTo() {
return in.isWithinCopyTo();
}

@Override
public boolean isWithinInferenceMetadata() {
return in.isWithinInferenceMetadata();
}

@Override
public ContentPath path() {
return in.path();
@@ -648,6 +653,10 @@ public boolean isWithinCopyTo() {
return false;
}

public boolean isWithinInferenceMetadata() {
return false;
}

boolean inArrayScope() {
return currentScope == Scope.ARRAY;
}
@@ -0,0 +1,90 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.index.mapper;

import org.apache.lucene.search.Query;
import org.elasticsearch.common.xcontent.XContentParserUtils;
import org.elasticsearch.index.query.QueryShardException;
import org.elasticsearch.index.query.SearchExecutionContext;
import org.elasticsearch.xcontent.XContentParser;

import java.io.IOException;
import java.util.Map;

public class InferenceMetadataFieldsMapper extends MetadataFieldMapper {
public static final String NAME = "_inference_fields";
public static final String CONTENT_TYPE = "_inference_fields";

private static final InferenceMetadataFieldsMapper INSTANCE = new InferenceMetadataFieldsMapper();

public static final TypeParser PARSER = new FixedTypeParser(c -> INSTANCE);

public static final class InferenceFieldType extends MappedFieldType {
private static InferenceFieldType INSTANCE = new InferenceFieldType();

public InferenceFieldType() {
super(NAME, false, false, false, TextSearchInfo.NONE, Map.of());
}

@Override
public ValueFetcher valueFetcher(SearchExecutionContext context, String format) {
// TODO: return the map from the individual semantic text fields?
return null;
}

@Override
public String typeName() {
return CONTENT_TYPE;
}

@Override
public Query termQuery(Object value, SearchExecutionContext context) {
throw new QueryShardException(
context,
"[" + name() + "] field which is of type [" + typeName() + "], does not support term queries"
);
}
}

private InferenceMetadataFieldsMapper() {
super(InferenceFieldType.INSTANCE);
}

@Override
protected String contentType() {
return CONTENT_TYPE;
}

@Override
protected boolean supportsParsingObject() {
return true;
}

@Override
protected void parseCreateField(DocumentParserContext context) throws IOException {
XContentParser parser = context.parser();
XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
XContentParserUtils.ensureExpectedToken(XContentParser.Token.FIELD_NAME, parser.currentToken(), parser);
String fieldName = parser.currentName();
Mapper mapper = context.mappingLookup().getMapper(fieldName);
if (mapper != null && mapper instanceof InferenceFieldMapper && mapper instanceof FieldMapper fieldMapper) {
fieldMapper.parseCreateField(new DocumentParserContext.Wrapper(context.parent(), context) {
@Override
public boolean isWithinInferenceMetadata() {
return true;
}
});
} else {
throw new IllegalArgumentException("Illegal inference field [" + fieldName + "] found.");
}
}
}
}
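Illustrative inputs for the parsing loop above, assuming a mapped semantic_text field named `body` (hypothetical; per-field payloads abbreviated). Keys of the `_inference_fields` object must name mapped inference fields; anything else is rejected:

// Hypothetical _inference_fields objects in the internal document source; abbreviated.
String accepted = """
    { "_inference_fields": { "body": { ... per-field inference results ... } } }
    """;
String rejected = """
    { "_inference_fields": { "not_an_inference_field": { ... } } }
    """;  // -> IllegalArgumentException: Illegal inference field [not_an_inference_field] found.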
@@ -413,7 +413,7 @@ public boolean isComplete() {
public void preParse(DocumentParserContext context) throws IOException {
BytesReference originalSource = context.sourceToParse().source();
XContentType contentType = context.sourceToParse().getXContentType();
final BytesReference adaptedSource = applyFilters(originalSource, contentType);
final BytesReference adaptedSource = applyFilters(context.mappingLookup(), originalSource, contentType);

if (adaptedSource != null) {
final BytesRef ref = adaptedSource.toBytesRef();
@@ -430,13 +430,26 @@ public void preParse(DocumentParserContext context) throws IOException {
}

@Nullable
public BytesReference applyFilters(@Nullable BytesReference originalSource, @Nullable XContentType contentType) throws IOException {
if (stored() == false) {
public BytesReference applyFilters(
@Nullable MappingLookup mappingLookup,
@Nullable BytesReference originalSource,
@Nullable XContentType contentType
) throws IOException {
if (stored() == false || originalSource == null) {
return null;
}
if (originalSource != null && sourceFilter != null) {
var modSourceFilter = sourceFilter;
if (mappingLookup != null && mappingLookup.inferenceFields().isEmpty() == false) {
String[] modExcludes = new String[excludes != null ? excludes.length + 1 : 1];
if (excludes != null) {
System.arraycopy(excludes, 0, modExcludes, 0, excludes.length);
}
modExcludes[modExcludes.length - 1] = InferenceMetadataFieldsMapper.NAME;
modSourceFilter = new SourceFilter(includes, modExcludes);
}
if (modSourceFilter != null) {
// Percolate and tv APIs may not set the source and that is ok, because these APIs will not index any data
return Source.fromBytes(originalSource, contentType).filter(sourceFilter).internalSourceRef();
return Source.fromBytes(originalSource, contentType).filter(modSourceFilter).internalSourceRef();
} else {
return originalSource;
}
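Because the hunk above interleaves removed and added lines, here is the changed applyFilters body reassembled in one piece (reconstructed from the added lines; annotations and the surrounding class omitted): when the mapping contains inference fields, `_inference_fields` is appended to the configured excludes so it never ends up in the stored `_source`.

public BytesReference applyFilters(MappingLookup mappingLookup, BytesReference originalSource, XContentType contentType) throws IOException {
    if (stored() == false || originalSource == null) {
        return null;
    }
    var modSourceFilter = sourceFilter;
    if (mappingLookup != null && mappingLookup.inferenceFields().isEmpty() == false) {
        // widen the user-configured excludes with the inference metadata field
        String[] modExcludes = new String[excludes != null ? excludes.length + 1 : 1];
        if (excludes != null) {
            System.arraycopy(excludes, 0, modExcludes, 0, excludes.length);
        }
        modExcludes[modExcludes.length - 1] = InferenceMetadataFieldsMapper.NAME;
        modSourceFilter = new SourceFilter(includes, modExcludes);
    }
    if (modSourceFilter != null) {
        return Source.fromBytes(originalSource, contentType).filter(modSourceFilter).internalSourceRef();
    } else {
        return originalSource;
    }
}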
@@ -42,6 +42,7 @@
import org.elasticsearch.index.mapper.IgnoredSourceFieldMapper;
import org.elasticsearch.index.mapper.IndexFieldMapper;
import org.elasticsearch.index.mapper.IndexModeFieldMapper;
import org.elasticsearch.index.mapper.InferenceMetadataFieldsMapper;
import org.elasticsearch.index.mapper.IpFieldMapper;
import org.elasticsearch.index.mapper.IpScriptFieldType;
import org.elasticsearch.index.mapper.KeywordFieldMapper;
@@ -272,6 +273,7 @@ private static Map<String, MetadataFieldMapper.TypeParser> initBuiltInMetadataMa
builtInMetadataMappers.put(SeqNoFieldMapper.NAME, SeqNoFieldMapper.PARSER);
builtInMetadataMappers.put(DocCountFieldMapper.NAME, DocCountFieldMapper.PARSER);
builtInMetadataMappers.put(DataStreamTimestampFieldMapper.NAME, DataStreamTimestampFieldMapper.PARSER);
builtInMetadataMappers.put(InferenceMetadataFieldsMapper.NAME, InferenceMetadataFieldsMapper.PARSER);
// _field_names must be added last so that it has a chance to see all the other mappers
builtInMetadataMappers.put(FieldNamesFieldMapper.NAME, FieldNamesFieldMapper.PARSER);
return Collections.unmodifiableMap(builtInMetadataMappers);
@@ -232,7 +232,7 @@ public Collection<?> createComponents(PluginServices services) {
}
inferenceServiceRegistry.set(registry);

var actionFilter = new ShardBulkInferenceActionFilter(registry, modelRegistry);
var actionFilter = new ShardBulkInferenceActionFilter(services.clusterService(), registry, modelRegistry);
shardBulkInferenceActionFilter.set(actionFilter);

var meterRegistry = services.telemetryProvider().getMeterRegistry();
@@ -24,11 +24,15 @@
import org.elasticsearch.action.support.RefCountingRunnable;
import org.elasticsearch.action.update.UpdateRequest;
import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.util.concurrent.AtomicArray;
import org.elasticsearch.common.xcontent.support.XContentMapValues;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.index.IndexVersion;
import org.elasticsearch.index.IndexVersions;
import org.elasticsearch.index.mapper.InferenceMetadataFieldsMapper;
import org.elasticsearch.inference.ChunkedInferenceServiceResults;
import org.elasticsearch.inference.ChunkingOptions;
import org.elasticsearch.inference.InferenceService;
@@ -68,15 +72,26 @@
public class ShardBulkInferenceActionFilter implements MappedActionFilter {
protected static final int DEFAULT_BATCH_SIZE = 512;

private final ClusterService clusterService;
private final InferenceServiceRegistry inferenceServiceRegistry;
private final ModelRegistry modelRegistry;
private final int batchSize;

public ShardBulkInferenceActionFilter(InferenceServiceRegistry inferenceServiceRegistry, ModelRegistry modelRegistry) {
this(inferenceServiceRegistry, modelRegistry, DEFAULT_BATCH_SIZE);
public ShardBulkInferenceActionFilter(
ClusterService clusterService,
InferenceServiceRegistry inferenceServiceRegistry,
ModelRegistry modelRegistry
) {
this(clusterService, inferenceServiceRegistry, modelRegistry, DEFAULT_BATCH_SIZE);
}

public ShardBulkInferenceActionFilter(InferenceServiceRegistry inferenceServiceRegistry, ModelRegistry modelRegistry, int batchSize) {
public ShardBulkInferenceActionFilter(
ClusterService clusterService,
InferenceServiceRegistry inferenceServiceRegistry,
ModelRegistry modelRegistry,
int batchSize
) {
this.clusterService = clusterService;
this.inferenceServiceRegistry = inferenceServiceRegistry;
this.modelRegistry = modelRegistry;
this.batchSize = batchSize;
@@ -112,7 +127,8 @@ private void processBulkShardRequest(
BulkShardRequest bulkShardRequest,
Runnable onCompletion
) {
new AsyncBulkShardInferenceAction(fieldInferenceMap, bulkShardRequest, onCompletion).run();
var index = clusterService.state().getMetadata().index(bulkShardRequest.index());
new AsyncBulkShardInferenceAction(index.getCreationVersion(), fieldInferenceMap, bulkShardRequest, onCompletion).run();
}

private record InferenceProvider(InferenceService service, Model model) {}
@@ -165,16 +181,19 @@ void addFailure(Exception exc) {
}

private class AsyncBulkShardInferenceAction implements Runnable {
private final IndexVersion indexCreatedVersion;
private final Map<String, InferenceFieldMetadata> fieldInferenceMap;
private final BulkShardRequest bulkShardRequest;
private final Runnable onCompletion;
private final AtomicArray<FieldInferenceResponseAccumulator> inferenceResults;

private AsyncBulkShardInferenceAction(
IndexVersion indexCreatedVersion,
Map<String, InferenceFieldMetadata> fieldInferenceMap,
BulkShardRequest bulkShardRequest,
Runnable onCompletion
) {
this.indexCreatedVersion = indexCreatedVersion;
this.fieldInferenceMap = fieldInferenceMap;
this.bulkShardRequest = bulkShardRequest;
this.inferenceResults = new AtomicArray<>(bulkShardRequest.items().length);
@@ -379,6 +398,8 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons

final IndexRequest indexRequest = getIndexRequestOrNull(item.request());
var newDocMap = indexRequest.sourceAsMap();
Map<String, Object> inferenceFieldsMap = new HashMap<>();
final boolean addMetadataField = indexCreatedVersion.onOrAfter(IndexVersions.INFERENCE_METADATA_FIELDS);
for (var entry : response.responses.entrySet()) {
var fieldName = entry.getKey();
var responses = entry.getValue();
Expand All @@ -397,7 +418,14 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons
),
indexRequest.getContentType()
);
SemanticTextFieldMapper.insertValue(fieldName, newDocMap, result);
if (addMetadataField) {
inferenceFieldsMap.put(fieldName, result);
} else {
SemanticTextFieldMapper.insertValue(fieldName, newDocMap, result);
}
}
if (addMetadataField) {
newDocMap.put(InferenceMetadataFieldsMapper.NAME, inferenceFieldsMap);
}
indexRequest.source(newDocMap, indexRequest.getContentType());
}
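And the rewritten tail of applyInferenceResponses in one piece (reconstructed from the hunk above; the conversion of each response into a source-map value is elided behind a hypothetical placeholder): on new indices the results are collected and attached under `_inference_fields`, on older indices they are still spliced into the original field.

var newDocMap = indexRequest.sourceAsMap();
Map<String, Object> inferenceFieldsMap = new HashMap<>();
final boolean addMetadataField = indexCreatedVersion.onOrAfter(IndexVersions.INFERENCE_METADATA_FIELDS);
for (var entry : response.responses.entrySet()) {
    var fieldName = entry.getKey();
    Object result = convertToFieldValue(entry.getValue());  // placeholder for the XContentMapValues conversion above
    if (addMetadataField) {
        inferenceFieldsMap.put(fieldName, result);                          // new layout: collect under the metadata field
    } else {
        SemanticTextFieldMapper.insertValue(fieldName, newDocMap, result);  // legacy layout: splice into the original field
    }
}
if (addMetadataField) {
    newDocMap.put(InferenceMetadataFieldsMapper.NAME, inferenceFieldsMap);  // "_inference_fields"
}
indexRequest.source(newDocMap, indexRequest.getContentType());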
@@ -21,6 +21,7 @@
import org.elasticsearch.features.NodeFeature;
import org.elasticsearch.index.IndexSettings;
import org.elasticsearch.index.IndexVersion;
import org.elasticsearch.index.IndexVersions;
import org.elasticsearch.index.fielddata.FieldDataContext;
import org.elasticsearch.index.fielddata.IndexFieldData;
import org.elasticsearch.index.mapper.BlockLoader;
@@ -286,6 +287,11 @@ public FieldMapper.Builder getMergeBuilder() {

@Override
protected void parseCreateField(DocumentParserContext context) throws IOException {
if (context.isWithinInferenceMetadata() == false) {
assert indexSettings.getIndexVersionCreated().onOrAfter(IndexVersions.INFERENCE_METADATA_FIELDS);
// ignore original text value
return;
}
XContentParser parser = context.parser();
if (parser.currentToken() == XContentParser.Token.VALUE_NULL) {
return;
@@ -495,8 +501,10 @@ public Query existsQuery(SearchExecutionContext context) {

@Override
public ValueFetcher valueFetcher(SearchExecutionContext context, String format) {
// Redirect the fetcher to load the original values of the field
return SourceValueFetcher.toString(getOriginalTextFieldName(name()), context, format);
String fieldName = context.getIndexSettings().getIndexVersionCreated().onOrAfter(IndexVersions.INFERENCE_METADATA_FIELDS)
? name()
: getOriginalTextFieldName(name());
return SourceValueFetcher.toString(fieldName, context, format);
}

@Override
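The final valueFetcher, reassembled from the interleaved lines above: new indices read the raw text from the field's own name in `_source` (since it now stays there), while indices created before INFERENCE_METADATA_FIELDS keep fetching from the derived original-text field.

@Override
public ValueFetcher valueFetcher(SearchExecutionContext context, String format) {
    String fieldName = context.getIndexSettings().getIndexVersionCreated().onOrAfter(IndexVersions.INFERENCE_METADATA_FIELDS)
        ? name()                              // new layout: raw text remains under the field's own name
        : getOriginalTextFieldName(name());   // legacy layout: text lives under the derived original-text field
    return SourceValueFetcher.toString(fieldName, context, format);
}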
@@ -320,7 +320,13 @@ private static ShardBulkInferenceActionFilter createFilter(ThreadPool threadPool

InferenceServiceRegistry inferenceServiceRegistry = mock(InferenceServiceRegistry.class);
when(inferenceServiceRegistry.getService(any())).thenReturn(Optional.of(inferenceService));
ShardBulkInferenceActionFilter filter = new ShardBulkInferenceActionFilter(inferenceServiceRegistry, modelRegistry, batchSize);
// TODO: add cluster service
ShardBulkInferenceActionFilter filter = new ShardBulkInferenceActionFilter(
null,
inferenceServiceRegistry,
modelRegistry,
batchSize
);
return filter;
}

