From 5f3cd33ef2b687bd33dbd6fc6baa85745d76224e Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Wed, 22 Nov 2023 18:19:47 +0100 Subject: [PATCH] Working implementation for indexing inference as a metadata field --- .../xpack/ml/MachineLearning.java | 15 +- .../SemanticTextInferenceProcessor.java | 3 +- .../ml/mapper/SemanticTextFieldMapper.java | 41 ++---- ...emanticTextInferenceResultFieldMapper.java | 137 ++++++++++++++++++ 4 files changed, 164 insertions(+), 32 deletions(-) create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/mapper/SemanticTextInferenceResultFieldMapper.java diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index 9c31599bd7c4f..19b3976ecc6e7 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -49,6 +49,7 @@ import org.elasticsearch.index.analysis.CharFilterFactory; import org.elasticsearch.index.analysis.TokenizerFactory; import org.elasticsearch.index.mapper.Mapper; +import org.elasticsearch.index.mapper.MetadataFieldMapper; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.indices.AssociatedIndexDescriptor; import org.elasticsearch.indices.SystemIndexDescriptor; @@ -364,6 +365,7 @@ import org.elasticsearch.xpack.ml.job.snapshot.upgrader.SnapshotUpgradeTaskExecutor; import org.elasticsearch.xpack.ml.job.task.OpenJobPersistentTasksExecutor; import org.elasticsearch.xpack.ml.mapper.SemanticTextFieldMapper; +import org.elasticsearch.xpack.ml.mapper.SemanticTextInferenceResultFieldMapper; import org.elasticsearch.xpack.ml.notifications.AnomalyDetectionAuditor; import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor; import org.elasticsearch.xpack.ml.notifications.InferenceAuditor; @@ -2278,7 +2280,18 @@ public void signalShutdown(Collection shutdownNodeIds) { @Override public Map getMappers() { - return Map.of(SemanticTextFieldMapper.CONTENT_TYPE, SemanticTextFieldMapper.PARSER); + return Map.of( + SemanticTextFieldMapper.CONTENT_TYPE, + SemanticTextFieldMapper.PARSER + ); + } + + @Override + public Map getMetadataMappers() { + return Map.of( + SemanticTextInferenceResultFieldMapper.CONTENT_TYPE, + SemanticTextInferenceResultFieldMapper.PARSER + ); } @Override diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/SemanticTextInferenceProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/SemanticTextInferenceProcessor.java index f34846bb5b62c..7d4eab2ec52a6 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/SemanticTextInferenceProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/SemanticTextInferenceProcessor.java @@ -14,6 +14,7 @@ import org.elasticsearch.ingest.Processor; import org.elasticsearch.ingest.WrappingProcessor; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextExpansionConfigUpdate; +import org.elasticsearch.xpack.ml.mapper.SemanticTextInferenceResultFieldMapper; import org.elasticsearch.xpack.ml.notifications.InferenceAuditor; import java.util.List; @@ -57,7 +58,7 @@ private Processor createWrappedProcessor() { private InferenceProcessor createInferenceProcessor(String modelId, Set fields) { List inputConfigs = fields.stream() - .map(f -> new InferenceProcessor.Factory.InputConfig(f, "ml.inference", f, Map.of())) + .map(f -> new InferenceProcessor.Factory.InputConfig(f, SemanticTextInferenceResultFieldMapper.NAME, f, Map.of())) .toList(); return InferenceProcessor.fromInputFieldConfiguration( diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/mapper/SemanticTextFieldMapper.java index 8267c7b228c3e..18db4961d2385 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/mapper/SemanticTextFieldMapper.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/mapper/SemanticTextFieldMapper.java @@ -16,8 +16,6 @@ import org.elasticsearch.index.mapper.SourceValueFetcher; import org.elasticsearch.index.mapper.TextSearchInfo; import org.elasticsearch.index.mapper.ValueFetcher; -import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper; -import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper.SparseVectorFieldType; import org.elasticsearch.index.query.SearchExecutionContext; import java.io.IOException; @@ -28,8 +26,6 @@ public class SemanticTextFieldMapper extends FieldMapper { public static final String CONTENT_TYPE = "semantic_text"; - public static final String TEXT_SUBFIELD_NAME = "text"; - public static final String SPARSE_VECTOR_SUBFIELD_NAME = "inference"; private static SemanticTextFieldMapper toType(FieldMapper in) { return (SemanticTextFieldMapper) in; @@ -66,16 +62,11 @@ private SemanticTextFieldType buildFieldType(MapperBuilderContext context) { @Override public SemanticTextFieldMapper build(MapperBuilderContext context) { - String fullName = context.buildFullName(name); - String subfieldName = fullName + "." + SPARSE_VECTOR_SUBFIELD_NAME; - SparseVectorFieldMapper sparseVectorFieldMapper = new SparseVectorFieldMapper.Builder(subfieldName).build(context); return new SemanticTextFieldMapper( name(), new SemanticTextFieldType(name(), modelId.getValue(), meta.getValue()), modelId.getValue(), - sparseVectorFieldMapper, - copyTo, - this + copyTo ); } } @@ -84,13 +75,10 @@ public SemanticTextFieldMapper build(MapperBuilderContext context) { public static class SemanticTextFieldType extends SimpleMappedFieldType { - private final SparseVectorFieldType sparseVectorFieldType; - private final String modelId; public SemanticTextFieldType(String name, String modelId, Map meta) { super(name, true, false, false, TextSearchInfo.NONE, meta); - this.sparseVectorFieldType = new SparseVectorFieldType(name + "." + SPARSE_VECTOR_SUBFIELD_NAME, meta); this.modelId = modelId; } @@ -98,10 +86,6 @@ public String modelId() { return modelId; } - public SparseVectorFieldType getSparseVectorFieldType() { - return this.sparseVectorFieldType; - } - @Override public String typeName() { return CONTENT_TYPE; @@ -111,36 +95,33 @@ public String getInferenceModel() { return modelId; } - @Override - public ValueFetcher valueFetcher(SearchExecutionContext context, String format) { - return SourceValueFetcher.identity(name(), context, format); - } - @Override public Query termQuery(Object value, SearchExecutionContext context) { - return sparseVectorFieldType.termQuery(value, context); + return null; } @Override - public Query existsQuery(SearchExecutionContext context) { - return sparseVectorFieldType.existsQuery(context); + public ValueFetcher valueFetcher(SearchExecutionContext context, String format) { + return SourceValueFetcher.identity(name(), context, format); } + + } private final String modelId; - private final SparseVectorFieldMapper sparseVectorFieldMapper; private SemanticTextFieldMapper( String simpleName, MappedFieldType mappedFieldType, String modelId, - SparseVectorFieldMapper sparseVectorFieldMapper, - CopyTo copyTo, - Builder builder + CopyTo copyTo ) { super(simpleName, mappedFieldType, MultiFields.empty(), copyTo); this.modelId = modelId; - this.sparseVectorFieldMapper = sparseVectorFieldMapper; + } + + public String getModelId() { + return modelId; } @Override diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/mapper/SemanticTextInferenceResultFieldMapper.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/mapper/SemanticTextInferenceResultFieldMapper.java new file mode 100644 index 0000000000000..30c253e46c690 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/mapper/SemanticTextInferenceResultFieldMapper.java @@ -0,0 +1,137 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.xpack.ml.mapper; + +import org.apache.lucene.search.Query; +import org.elasticsearch.index.mapper.DocumentParserContext; +import org.elasticsearch.index.mapper.FieldMapper; +import org.elasticsearch.index.mapper.MappedFieldType; +import org.elasticsearch.index.mapper.Mapper; +import org.elasticsearch.index.mapper.MapperBuilderContext; +import org.elasticsearch.index.mapper.MetadataFieldMapper; +import org.elasticsearch.index.mapper.SourceLoader; +import org.elasticsearch.index.mapper.SourceValueFetcher; +import org.elasticsearch.index.mapper.TextSearchInfo; +import org.elasticsearch.index.mapper.ValueFetcher; +import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper; +import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper.SparseVectorFieldType; +import org.elasticsearch.index.query.SearchExecutionContext; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Collections; +import java.util.Map; + +public class SemanticTextInferenceResultFieldMapper extends MetadataFieldMapper { + + public static final String CONTENT_TYPE = "semantic_text_inference"; + + public static final String NAME = "_semantic_text_inference"; + public static final String SPARSE_VECTOR_SUBFIELD_NAME = TaskType.SPARSE_EMBEDDING.toString(); + + private static final SemanticTextInferenceResultFieldMapper INSTANCE = new SemanticTextInferenceResultFieldMapper(); + + private static SemanticTextInferenceResultFieldMapper toType(FieldMapper in) { + return (SemanticTextInferenceResultFieldMapper) in; + } + + public static final TypeParser PARSER = new FixedTypeParser(c -> new SemanticTextInferenceResultFieldMapper()); + + public static class SemanticTextInferenceFieldType extends MappedFieldType { + + public static final MappedFieldType INSTANCE = new SemanticTextInferenceFieldType(); + private SparseVectorFieldType sparseVectorFieldType; + + public SemanticTextInferenceFieldType() { + super(NAME, true, false, false, TextSearchInfo.NONE, Collections.emptyMap()); + } + + @Override + public String typeName() { + return CONTENT_TYPE; + } + + @Override + public ValueFetcher valueFetcher(SearchExecutionContext context, String format) { + return SourceValueFetcher.identity(name(), context, format); + } + + @Override + public Query termQuery(Object value, SearchExecutionContext context) { + return sparseVectorFieldType.termQuery(value, context); + } + } + + private SemanticTextInferenceResultFieldMapper() { + super(SemanticTextInferenceFieldType.INSTANCE); + } + + @Override + public void parse(DocumentParserContext context) throws IOException { + + if (context.parser().currentToken() != XContentParser.Token.START_OBJECT) { + throw new IllegalArgumentException( + "[_semantic_text_inference] fields must be a json object, expected a START_OBJECT but got: " + + context.parser().currentToken() + ); + } + + MapperBuilderContext mapperBuilderContext = MapperBuilderContext.root(false, false).createChildContext(NAME); + + // TODO Can we validate that semantic text fields have actual text values? + for (XContentParser.Token token = context.parser().nextToken(); token != XContentParser.Token.END_OBJECT; token = context.parser() + .nextToken()) { + if (token != XContentParser.Token.FIELD_NAME) { + throw new IllegalArgumentException("[semantic_text] fields expect an object with field names, found " + token); + } + + String fieldName = context.parser().currentName(); + + Mapper mapper = context.getMapper(fieldName); + if (mapper == null) { + // Not a field we have mapped? Must be model output, skip it + context.parser().nextToken(); + context.path().setWithinLeafObject(true); + Map fieldMap = context.parser().map(); + context.path().setWithinLeafObject(false); + continue; + } + if (SemanticTextFieldMapper.CONTENT_TYPE.equals(mapper.typeName()) == false) { + throw new IllegalArgumentException( + "Found [" + fieldName + "] in inference values, but it is not registered as a semantic_text field type" + ); + } + + context.parser().nextToken(); + SparseVectorFieldMapper sparseVectorFieldMapper = new SparseVectorFieldMapper.Builder(fieldName).build(mapperBuilderContext); + sparseVectorFieldMapper.parse(context); + } + } + + @Override + protected void parseCreateField(DocumentParserContext context) { + throw new AssertionError("parse is implemented directly"); + } + + @Override + public SourceLoader.SyntheticFieldLoader syntheticFieldLoader() { + return SourceLoader.SyntheticFieldLoader.NOTHING; + } + + @Override + protected String contentType() { + return CONTENT_TYPE; + } + + @Override + public SemanticTextInferenceFieldType fieldType() { + return (SemanticTextInferenceFieldType) super.fieldType(); + } +}