Skip to content

Commit

Permalink
Working implementation for indexing inference as a metadata field
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosdelest committed Nov 22, 2023
1 parent ab97838 commit 5f3cd33
Show file tree
Hide file tree
Showing 4 changed files with 164 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import org.elasticsearch.index.analysis.CharFilterFactory;
import org.elasticsearch.index.analysis.TokenizerFactory;
import org.elasticsearch.index.mapper.Mapper;
import org.elasticsearch.index.mapper.MetadataFieldMapper;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.indices.AssociatedIndexDescriptor;
import org.elasticsearch.indices.SystemIndexDescriptor;
Expand Down Expand Up @@ -364,6 +365,7 @@
import org.elasticsearch.xpack.ml.job.snapshot.upgrader.SnapshotUpgradeTaskExecutor;
import org.elasticsearch.xpack.ml.job.task.OpenJobPersistentTasksExecutor;
import org.elasticsearch.xpack.ml.mapper.SemanticTextFieldMapper;
import org.elasticsearch.xpack.ml.mapper.SemanticTextInferenceResultFieldMapper;
import org.elasticsearch.xpack.ml.notifications.AnomalyDetectionAuditor;
import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
import org.elasticsearch.xpack.ml.notifications.InferenceAuditor;
Expand Down Expand Up @@ -2278,7 +2280,18 @@ public void signalShutdown(Collection<String> shutdownNodeIds) {

@Override
public Map<String, Mapper.TypeParser> getMappers() {
return Map.of(SemanticTextFieldMapper.CONTENT_TYPE, SemanticTextFieldMapper.PARSER);
return Map.of(
SemanticTextFieldMapper.CONTENT_TYPE,
SemanticTextFieldMapper.PARSER
);
}

@Override
public Map<String, MetadataFieldMapper.TypeParser> getMetadataMappers() {
return Map.of(
SemanticTextInferenceResultFieldMapper.CONTENT_TYPE,
SemanticTextInferenceResultFieldMapper.PARSER
);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import org.elasticsearch.ingest.Processor;
import org.elasticsearch.ingest.WrappingProcessor;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextExpansionConfigUpdate;
import org.elasticsearch.xpack.ml.mapper.SemanticTextInferenceResultFieldMapper;
import org.elasticsearch.xpack.ml.notifications.InferenceAuditor;

import java.util.List;
Expand Down Expand Up @@ -57,7 +58,7 @@ private Processor createWrappedProcessor() {

private InferenceProcessor createInferenceProcessor(String modelId, Set<String> fields) {
List<InferenceProcessor.Factory.InputConfig> inputConfigs = fields.stream()
.map(f -> new InferenceProcessor.Factory.InputConfig(f, "ml.inference", f, Map.of()))
.map(f -> new InferenceProcessor.Factory.InputConfig(f, SemanticTextInferenceResultFieldMapper.NAME, f, Map.of()))
.toList();

return InferenceProcessor.fromInputFieldConfiguration(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@
import org.elasticsearch.index.mapper.SourceValueFetcher;
import org.elasticsearch.index.mapper.TextSearchInfo;
import org.elasticsearch.index.mapper.ValueFetcher;
import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper;
import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper.SparseVectorFieldType;
import org.elasticsearch.index.query.SearchExecutionContext;

import java.io.IOException;
Expand All @@ -28,8 +26,6 @@ public class SemanticTextFieldMapper extends FieldMapper {

public static final String CONTENT_TYPE = "semantic_text";

public static final String TEXT_SUBFIELD_NAME = "text";
public static final String SPARSE_VECTOR_SUBFIELD_NAME = "inference";

private static SemanticTextFieldMapper toType(FieldMapper in) {
return (SemanticTextFieldMapper) in;
Expand Down Expand Up @@ -66,16 +62,11 @@ private SemanticTextFieldType buildFieldType(MapperBuilderContext context) {

@Override
public SemanticTextFieldMapper build(MapperBuilderContext context) {
String fullName = context.buildFullName(name);
String subfieldName = fullName + "." + SPARSE_VECTOR_SUBFIELD_NAME;
SparseVectorFieldMapper sparseVectorFieldMapper = new SparseVectorFieldMapper.Builder(subfieldName).build(context);
return new SemanticTextFieldMapper(
name(),
new SemanticTextFieldType(name(), modelId.getValue(), meta.getValue()),
modelId.getValue(),
sparseVectorFieldMapper,
copyTo,
this
copyTo
);
}
}
Expand All @@ -84,24 +75,17 @@ public SemanticTextFieldMapper build(MapperBuilderContext context) {

public static class SemanticTextFieldType extends SimpleMappedFieldType {

private final SparseVectorFieldType sparseVectorFieldType;

private final String modelId;

public SemanticTextFieldType(String name, String modelId, Map<String, String> meta) {
super(name, true, false, false, TextSearchInfo.NONE, meta);
this.sparseVectorFieldType = new SparseVectorFieldType(name + "." + SPARSE_VECTOR_SUBFIELD_NAME, meta);
this.modelId = modelId;
}

public String modelId() {
return modelId;
}

public SparseVectorFieldType getSparseVectorFieldType() {
return this.sparseVectorFieldType;
}

@Override
public String typeName() {
return CONTENT_TYPE;
Expand All @@ -111,36 +95,33 @@ public String getInferenceModel() {
return modelId;
}

@Override
public ValueFetcher valueFetcher(SearchExecutionContext context, String format) {
return SourceValueFetcher.identity(name(), context, format);
}

@Override
public Query termQuery(Object value, SearchExecutionContext context) {
return sparseVectorFieldType.termQuery(value, context);
return null;
}

@Override
public Query existsQuery(SearchExecutionContext context) {
return sparseVectorFieldType.existsQuery(context);
public ValueFetcher valueFetcher(SearchExecutionContext context, String format) {
return SourceValueFetcher.identity(name(), context, format);
}


}

private final String modelId;
private final SparseVectorFieldMapper sparseVectorFieldMapper;

private SemanticTextFieldMapper(
String simpleName,
MappedFieldType mappedFieldType,
String modelId,
SparseVectorFieldMapper sparseVectorFieldMapper,
CopyTo copyTo,
Builder builder
CopyTo copyTo
) {
super(simpleName, mappedFieldType, MultiFields.empty(), copyTo);
this.modelId = modelId;
this.sparseVectorFieldMapper = sparseVectorFieldMapper;
}

public String getModelId() {
return modelId;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/

package org.elasticsearch.xpack.ml.mapper;

import org.apache.lucene.search.Query;
import org.elasticsearch.index.mapper.DocumentParserContext;
import org.elasticsearch.index.mapper.FieldMapper;
import org.elasticsearch.index.mapper.MappedFieldType;
import org.elasticsearch.index.mapper.Mapper;
import org.elasticsearch.index.mapper.MapperBuilderContext;
import org.elasticsearch.index.mapper.MetadataFieldMapper;
import org.elasticsearch.index.mapper.SourceLoader;
import org.elasticsearch.index.mapper.SourceValueFetcher;
import org.elasticsearch.index.mapper.TextSearchInfo;
import org.elasticsearch.index.mapper.ValueFetcher;
import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper;
import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper.SparseVectorFieldType;
import org.elasticsearch.index.query.SearchExecutionContext;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xcontent.XContentParser;

import java.io.IOException;
import java.util.Collections;
import java.util.Map;

public class SemanticTextInferenceResultFieldMapper extends MetadataFieldMapper {

public static final String CONTENT_TYPE = "semantic_text_inference";

public static final String NAME = "_semantic_text_inference";
public static final String SPARSE_VECTOR_SUBFIELD_NAME = TaskType.SPARSE_EMBEDDING.toString();

private static final SemanticTextInferenceResultFieldMapper INSTANCE = new SemanticTextInferenceResultFieldMapper();

private static SemanticTextInferenceResultFieldMapper toType(FieldMapper in) {
return (SemanticTextInferenceResultFieldMapper) in;
}

public static final TypeParser PARSER = new FixedTypeParser(c -> new SemanticTextInferenceResultFieldMapper());

public static class SemanticTextInferenceFieldType extends MappedFieldType {

public static final MappedFieldType INSTANCE = new SemanticTextInferenceFieldType();
private SparseVectorFieldType sparseVectorFieldType;

public SemanticTextInferenceFieldType() {
super(NAME, true, false, false, TextSearchInfo.NONE, Collections.emptyMap());
}

@Override
public String typeName() {
return CONTENT_TYPE;
}

@Override
public ValueFetcher valueFetcher(SearchExecutionContext context, String format) {
return SourceValueFetcher.identity(name(), context, format);
}

@Override
public Query termQuery(Object value, SearchExecutionContext context) {
return sparseVectorFieldType.termQuery(value, context);
}
}

private SemanticTextInferenceResultFieldMapper() {
super(SemanticTextInferenceFieldType.INSTANCE);
}

@Override
public void parse(DocumentParserContext context) throws IOException {

if (context.parser().currentToken() != XContentParser.Token.START_OBJECT) {
throw new IllegalArgumentException(
"[_semantic_text_inference] fields must be a json object, expected a START_OBJECT but got: "
+ context.parser().currentToken()
);
}

MapperBuilderContext mapperBuilderContext = MapperBuilderContext.root(false, false).createChildContext(NAME);

// TODO Can we validate that semantic text fields have actual text values?
for (XContentParser.Token token = context.parser().nextToken(); token != XContentParser.Token.END_OBJECT; token = context.parser()
.nextToken()) {
if (token != XContentParser.Token.FIELD_NAME) {
throw new IllegalArgumentException("[semantic_text] fields expect an object with field names, found " + token);
}

String fieldName = context.parser().currentName();

Mapper mapper = context.getMapper(fieldName);
if (mapper == null) {
// Not a field we have mapped? Must be model output, skip it
context.parser().nextToken();
context.path().setWithinLeafObject(true);
Map<String, Object> fieldMap = context.parser().map();
context.path().setWithinLeafObject(false);
continue;
}
if (SemanticTextFieldMapper.CONTENT_TYPE.equals(mapper.typeName()) == false) {
throw new IllegalArgumentException(
"Found [" + fieldName + "] in inference values, but it is not registered as a semantic_text field type"
);
}

context.parser().nextToken();
SparseVectorFieldMapper sparseVectorFieldMapper = new SparseVectorFieldMapper.Builder(fieldName).build(mapperBuilderContext);
sparseVectorFieldMapper.parse(context);
}
}

@Override
protected void parseCreateField(DocumentParserContext context) {
throw new AssertionError("parse is implemented directly");
}

@Override
public SourceLoader.SyntheticFieldLoader syntheticFieldLoader() {
return SourceLoader.SyntheticFieldLoader.NOTHING;
}

@Override
protected String contentType() {
return CONTENT_TYPE;
}

@Override
public SemanticTextInferenceFieldType fieldType() {
return (SemanticTextInferenceFieldType) super.fieldType();
}
}

0 comments on commit 5f3cd33

Please sign in to comment.