diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/mapper/SemanticTextFieldMapper.java index 87df6559426f6..83c81ebd0e0cf 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/mapper/SemanticTextFieldMapper.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/mapper/SemanticTextFieldMapper.java @@ -7,7 +7,13 @@ package org.elasticsearch.xpack.ml.mapper; +import org.apache.lucene.document.FeatureField; +import org.apache.lucene.search.BooleanClause; +import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.Query; +import org.apache.lucene.search.join.ScoreMode; +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.common.lucene.search.Queries; import org.elasticsearch.index.mapper.DocumentParserContext; import org.elasticsearch.index.mapper.FieldMapper; import org.elasticsearch.index.mapper.MappedFieldType; @@ -18,6 +24,8 @@ import org.elasticsearch.index.mapper.ValueFetcher; import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper; import org.elasticsearch.index.query.SearchExecutionContext; +import org.elasticsearch.index.search.ESToParentBlockJoinQuery; +import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; import java.io.IOException; import java.util.Map; @@ -27,7 +35,6 @@ public class SemanticTextFieldMapper extends FieldMapper { public static final String CONTENT_TYPE = "semantic_text"; - private static SemanticTextFieldMapper toType(FieldMapper in) { return (SemanticTextFieldMapper) in; } @@ -57,10 +64,6 @@ protected Parameter[] getParameters() { return new Parameter[] { modelId, meta }; } - private SemanticTextFieldType buildFieldType(MapperBuilderContext context) { - return new SemanticTextFieldType(context.buildFullName(name), modelId.getValue(), meta.getValue()); - } - @Override public SemanticTextFieldMapper build(MapperBuilderContext context) { return new SemanticTextFieldMapper( @@ -107,6 +110,29 @@ public Query termQuery(Object value, SearchExecutionContext context) { return sparseVectorFieldType.termQuery(value, context); } + public Query textExpansionQuery(TextExpansionResults expansionResults, SearchExecutionContext context) { + String fieldName = name() + "." + "inference"; + BooleanQuery.Builder queryBuilder = new BooleanQuery.Builder(); + for (var weightedToken : expansionResults.getWeightedTokens()) { + queryBuilder.add( + new BooleanClause( + FeatureField.newLinearQuery(fieldName, indexedValueForSearch(weightedToken.token()), weightedToken.weight()), + BooleanClause.Occur.SHOULD + ) + ); + } + queryBuilder.setMinimumNumberShouldMatch(1); + var parentFilter = context.bitsetFilter(Queries.newNonNestedFilter(context.indexVersionCreated())); + return new ESToParentBlockJoinQuery(queryBuilder.build(), parentFilter, ScoreMode.Total, name()); + } + + private static String indexedValueForSearch(Object value) { + if (value instanceof BytesRef) { + return ((BytesRef) value).utf8ToString(); + } + return value.toString(); + } + @Override public ValueFetcher valueFetcher(SearchExecutionContext context, String format) { return SourceValueFetcher.identity(name(), context, format); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/SemanticQueryBuilder.java similarity index 89% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java rename to x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/SemanticQueryBuilder.java index 6833bcb780363..f85f68e657699 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/SemanticQueryBuilder.java @@ -5,10 +5,9 @@ * 2.0. */ -package org.elasticsearch.xpack.inference.queries; +package org.elasticsearch.xpack.ml.queries; import org.apache.lucene.search.Query; -import org.apache.lucene.search.join.ScoreMode; import org.apache.lucene.util.SetOnce; import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; @@ -16,20 +15,19 @@ import org.elasticsearch.common.ParsingException; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.lucene.search.Queries; import org.elasticsearch.index.query.AbstractQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.index.query.QueryRewriteContext; import org.elasticsearch.index.query.SearchExecutionContext; -import org.elasticsearch.index.search.ESToParentBlockJoinQuery; import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.inference.TaskType; import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; -import org.elasticsearch.xpack.inference.action.InferenceAction; +import org.elasticsearch.xpack.ml.mapper.SemanticTextFieldMapper; import java.io.IOException; import java.util.List; @@ -46,7 +44,6 @@ public class SemanticQueryBuilder extends AbstractQueryBuilder inferenceResultsList = inferenceResultsSupplier.get(); + if (inferenceResultsList == null) { + throw new IllegalArgumentException("No inference retrieved for field " + fieldName); + } + if (inferenceResultsList.size() != 1) { + throw new IllegalArgumentException("received multiple inference results for field " + fieldName); + } + + InferenceResults inferenceResults = inferenceResultsList.get(0); + if (inferenceResults instanceof TextExpansionResults expansionResults) { + SemanticTextFieldMapper.SemanticTextFieldType mapper = (SemanticTextFieldMapper.SemanticTextFieldType) context.getFieldType( + fieldName + ); + return mapper.textExpansionQuery(expansionResults, context); + } + + throw new IllegalArgumentException( + "field [" + fieldName + "] does not use a model that outputs sparse vector inference results" + ); } @Override