diff --git a/docs/changelog/118617.yaml b/docs/changelog/118617.yaml new file mode 100644 index 0000000000000..a8793a114e913 --- /dev/null +++ b/docs/changelog/118617.yaml @@ -0,0 +1,5 @@ +pr: 118617 +summary: Add support for `sparse_vector` queries against `semantic_text` fields +area: "Search" +type: enhancement +issues: [] diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/search/SparseVectorQueryBuilder.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/search/SparseVectorQueryBuilder.java index e9e4e90421adc..35cba890e5e0c 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/search/SparseVectorQueryBuilder.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/search/SparseVectorQueryBuilder.java @@ -90,7 +90,8 @@ public SparseVectorQueryBuilder( : (this.shouldPruneTokens ? new TokenPruningConfig() : null)); this.weightedTokensSupplier = null; - if (queryVectors == null ^ inferenceId == null == false) { + // Preserve BWC error messaging + if (queryVectors != null && inferenceId != null) { throw new IllegalArgumentException( "[" + NAME @@ -98,18 +99,24 @@ public SparseVectorQueryBuilder( + QUERY_VECTOR_FIELD.getPreferredName() + "] or [" + INFERENCE_ID_FIELD.getPreferredName() - + "]" + + "] for " + + ALLOWED_FIELD_TYPE + + " fields" ); } - if (inferenceId != null && query == null) { + + // Preserve BWC error messaging + if ((queryVectors == null) == (query == null)) { throw new IllegalArgumentException( "[" + NAME - + "] requires [" - + QUERY_FIELD.getPreferredName() - + "] when [" + + "] requires one of [" + + QUERY_VECTOR_FIELD.getPreferredName() + + "] or [" + INFERENCE_ID_FIELD.getPreferredName() - + "] is specified" + + "] for " + + ALLOWED_FIELD_TYPE + + " fields" ); } } @@ -143,6 +150,14 @@ public List getQueryVectors() { return queryVectors; } + public String getInferenceId() { + return inferenceId; + } + + public String getQuery() { + return query; + } + public boolean shouldPruneTokens() { return shouldPruneTokens; } @@ -176,7 +191,9 @@ protected void doXContent(XContentBuilder builder, Params params) throws IOExcep } builder.endObject(); } else { - builder.field(INFERENCE_ID_FIELD.getPreferredName(), inferenceId); + if (inferenceId != null) { + builder.field(INFERENCE_ID_FIELD.getPreferredName(), inferenceId); + } builder.field(QUERY_FIELD.getPreferredName(), query); } builder.field(PRUNE_FIELD.getPreferredName(), shouldPruneTokens); @@ -228,6 +245,11 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) { shouldPruneTokens, tokenPruningConfig ); + } else if (inferenceId == null) { + // Edge case, where inference_id was not specified in the request, + // but we did not intercept this and rewrite to a query o field with + // pre-configured inference. So we trap here and output a nicer error message. + throw new IllegalArgumentException("inference_id required to perform vector search on query string"); } // TODO move this to xpack core and use inference APIs diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/search/SparseVectorQueryBuilderTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/search/SparseVectorQueryBuilderTests.java index 7774b29bfc971..c272a773b2b59 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/search/SparseVectorQueryBuilderTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/search/SparseVectorQueryBuilderTests.java @@ -260,16 +260,16 @@ public void testIllegalValues() { { IllegalArgumentException e = expectThrows( IllegalArgumentException.class, - () -> new SparseVectorQueryBuilder("field name", null, "model id") + () -> new SparseVectorQueryBuilder("field name", null, null) ); - assertEquals("[sparse_vector] requires one of [query_vector] or [inference_id]", e.getMessage()); + assertEquals("[sparse_vector] requires one of [query_vector] or [inference_id] for sparse_vector fields", e.getMessage()); } { IllegalArgumentException e = expectThrows( IllegalArgumentException.class, () -> new SparseVectorQueryBuilder("field name", "model text", null) ); - assertEquals("[sparse_vector] requires [query] when [inference_id] is specified", e.getMessage()); + assertEquals("[sparse_vector] requires one of [query_vector] or [inference_id] for sparse_vector fields", e.getMessage()); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java index a7a6004c0ebb2..3fd671383369d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java @@ -10,12 +10,14 @@ import org.elasticsearch.features.FeatureSpecification; import org.elasticsearch.features.NodeFeature; import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; -import org.elasticsearch.xpack.inference.queries.SemanticMatchQueryRewriteInterceptor; import org.elasticsearch.xpack.inference.rank.random.RandomRankRetrieverBuilder; import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder; import java.util.Set; +import static org.elasticsearch.xpack.inference.queries.SemanticMatchQueryRewriteInterceptor.SEMANTIC_MATCH_QUERY_REWRITE_INTERCEPTION_SUPPORTED; +import static org.elasticsearch.xpack.inference.queries.SemanticSparseVectorQueryRewriteInterceptor.SEMANTIC_SPARSE_VECTOR_QUERY_REWRITE_INTERCEPTION_SUPPORTED; + /** * Provides inference features. */ @@ -43,7 +45,8 @@ public Set getTestFeatures() { SemanticTextFieldMapper.SEMANTIC_TEXT_ZERO_SIZE_FIX, SemanticTextFieldMapper.SEMANTIC_TEXT_ALWAYS_EMIT_INFERENCE_ID_FIX, SEMANTIC_TEXT_HIGHLIGHTER, - SemanticMatchQueryRewriteInterceptor.SEMANTIC_MATCH_QUERY_REWRITE_INTERCEPTION_SUPPORTED + SEMANTIC_MATCH_QUERY_REWRITE_INTERCEPTION_SUPPORTED, + SEMANTIC_SPARSE_VECTOR_QUERY_REWRITE_INTERCEPTION_SUPPORTED ); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index d2dc49a39888f..5513ab6fe7de2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -75,6 +75,7 @@ import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; import org.elasticsearch.xpack.inference.queries.SemanticMatchQueryRewriteInterceptor; import org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder; +import org.elasticsearch.xpack.inference.queries.SemanticSparseVectorQueryRewriteInterceptor; import org.elasticsearch.xpack.inference.rank.random.RandomRankBuilder; import org.elasticsearch.xpack.inference.rank.random.RandomRankRetrieverBuilder; import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankBuilder; @@ -404,7 +405,7 @@ public List> getQueries() { @Override public List getQueryRewriteInterceptors() { - return List.of(new SemanticMatchQueryRewriteInterceptor()); + return List.of(new SemanticMatchQueryRewriteInterceptor(), new SemanticSparseVectorQueryRewriteInterceptor()); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticMatchQueryRewriteInterceptor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticMatchQueryRewriteInterceptor.java index a4a8123935c3e..fd1d65d00faf5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticMatchQueryRewriteInterceptor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticMatchQueryRewriteInterceptor.java @@ -7,24 +7,12 @@ package org.elasticsearch.xpack.inference.queries; -import org.elasticsearch.action.ResolvedIndices; -import org.elasticsearch.cluster.metadata.IndexMetadata; -import org.elasticsearch.cluster.metadata.InferenceFieldMetadata; import org.elasticsearch.features.NodeFeature; -import org.elasticsearch.index.mapper.IndexFieldMapper; import org.elasticsearch.index.query.BoolQueryBuilder; import org.elasticsearch.index.query.MatchQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; -import org.elasticsearch.index.query.QueryRewriteContext; -import org.elasticsearch.index.query.TermQueryBuilder; -import org.elasticsearch.index.query.TermsQueryBuilder; -import org.elasticsearch.plugins.internal.rewriter.QueryRewriteInterceptor; -import java.util.ArrayList; -import java.util.Collection; -import java.util.List; - -public class SemanticMatchQueryRewriteInterceptor implements QueryRewriteInterceptor { +public class SemanticMatchQueryRewriteInterceptor extends SemanticQueryRewriteInterceptor { public static final NodeFeature SEMANTIC_MATCH_QUERY_REWRITE_INTERCEPTION_SUPPORTED = new NodeFeature( "search.semantic_match_query_rewrite_interception_supported" @@ -33,63 +21,45 @@ public class SemanticMatchQueryRewriteInterceptor implements QueryRewriteInterce public SemanticMatchQueryRewriteInterceptor() {} @Override - public QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilder queryBuilder) { + protected String getFieldName(QueryBuilder queryBuilder) { assert (queryBuilder instanceof MatchQueryBuilder); MatchQueryBuilder matchQueryBuilder = (MatchQueryBuilder) queryBuilder; - QueryBuilder rewritten = queryBuilder; - ResolvedIndices resolvedIndices = context.getResolvedIndices(); - if (resolvedIndices != null) { - Collection indexMetadataCollection = resolvedIndices.getConcreteLocalIndicesMetadata().values(); - List inferenceIndices = new ArrayList<>(); - List nonInferenceIndices = new ArrayList<>(); - for (IndexMetadata indexMetadata : indexMetadataCollection) { - String indexName = indexMetadata.getIndex().getName(); - InferenceFieldMetadata inferenceFieldMetadata = indexMetadata.getInferenceFields().get(matchQueryBuilder.fieldName()); - if (inferenceFieldMetadata != null) { - inferenceIndices.add(indexName); - } else { - nonInferenceIndices.add(indexName); - } - } - - if (inferenceIndices.isEmpty()) { - return rewritten; - } else if (nonInferenceIndices.isEmpty() == false) { - BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); - for (String inferenceIndexName : inferenceIndices) { - // Add a separate clause for each semantic query, because they may be using different inference endpoints - // TODO - consolidate this to a single clause once the semantic query supports multiple inference endpoints - boolQueryBuilder.should( - createSemanticSubQuery(inferenceIndexName, matchQueryBuilder.fieldName(), (String) matchQueryBuilder.value()) - ); - } - boolQueryBuilder.should(createMatchSubQuery(nonInferenceIndices, matchQueryBuilder)); - rewritten = boolQueryBuilder; - } else { - rewritten = new SemanticQueryBuilder(matchQueryBuilder.fieldName(), (String) matchQueryBuilder.value(), false); - } - } - - return rewritten; + return matchQueryBuilder.fieldName(); + } + @Override + protected String getQuery(QueryBuilder queryBuilder) { + assert (queryBuilder instanceof MatchQueryBuilder); + MatchQueryBuilder matchQueryBuilder = (MatchQueryBuilder) queryBuilder; + return (String) matchQueryBuilder.value(); } @Override - public String getQueryName() { - return MatchQueryBuilder.NAME; + protected QueryBuilder buildInferenceQuery(QueryBuilder queryBuilder, InferenceIndexInformationForField indexInformation) { + return new SemanticQueryBuilder(indexInformation.fieldName(), getQuery(queryBuilder), false); } - private QueryBuilder createSemanticSubQuery(String indexName, String fieldName, String value) { + @Override + protected QueryBuilder buildCombinedInferenceAndNonInferenceQuery( + QueryBuilder queryBuilder, + InferenceIndexInformationForField indexInformation + ) { + assert (queryBuilder instanceof MatchQueryBuilder); + MatchQueryBuilder matchQueryBuilder = (MatchQueryBuilder) queryBuilder; BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); - boolQueryBuilder.must(new SemanticQueryBuilder(fieldName, value, true)); - boolQueryBuilder.filter(new TermQueryBuilder(IndexFieldMapper.NAME, indexName)); + boolQueryBuilder.should( + createSemanticSubQuery( + indexInformation.getInferenceIndices(), + matchQueryBuilder.fieldName(), + (String) matchQueryBuilder.value() + ) + ); + boolQueryBuilder.should(createSubQueryForIndices(indexInformation.nonInferenceIndices(), matchQueryBuilder)); return boolQueryBuilder; } - private QueryBuilder createMatchSubQuery(List indices, MatchQueryBuilder matchQueryBuilder) { - BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); - boolQueryBuilder.must(matchQueryBuilder); - boolQueryBuilder.filter(new TermsQueryBuilder(IndexFieldMapper.NAME, indices)); - return boolQueryBuilder; + @Override + public String getQueryName() { + return MatchQueryBuilder.NAME; } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java index 30094ff7dbdfc..285739fe0936f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java @@ -144,6 +144,14 @@ public String getWriteableName() { return NAME; } + public String getFieldName() { + return fieldName; + } + + public String getQuery() { + return query; + } + @Override public TransportVersion getMinimalSupportedVersion() { return TransportVersions.V_8_15_0; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryRewriteInterceptor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryRewriteInterceptor.java new file mode 100644 index 0000000000000..bb76ef0be24e9 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryRewriteInterceptor.java @@ -0,0 +1,148 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.queries; + +import org.elasticsearch.action.ResolvedIndices; +import org.elasticsearch.cluster.metadata.IndexMetadata; +import org.elasticsearch.cluster.metadata.InferenceFieldMetadata; +import org.elasticsearch.index.mapper.IndexFieldMapper; +import org.elasticsearch.index.query.BoolQueryBuilder; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.QueryRewriteContext; +import org.elasticsearch.index.query.TermsQueryBuilder; +import org.elasticsearch.plugins.internal.rewriter.QueryRewriteInterceptor; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +/** + * Intercepts and adapts a query to be rewritten to work seamlessly on a semantic_text field. + */ +public abstract class SemanticQueryRewriteInterceptor implements QueryRewriteInterceptor { + + public SemanticQueryRewriteInterceptor() {} + + @Override + public QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilder queryBuilder) { + String fieldName = getFieldName(queryBuilder); + ResolvedIndices resolvedIndices = context.getResolvedIndices(); + + if (resolvedIndices == null) { + // No resolved indices, so return the original query. + return queryBuilder; + } + + InferenceIndexInformationForField indexInformation = resolveIndicesForField(fieldName, resolvedIndices); + if (indexInformation.getInferenceIndices().isEmpty()) { + // No inference fields were identified, so return the original query. + return queryBuilder; + } else if (indexInformation.nonInferenceIndices().isEmpty() == false) { + // Combined case where the field name requested by this query contains both + // semantic_text and non-inference fields, so we have to combine queries per index + // containing each field type. + return buildCombinedInferenceAndNonInferenceQuery(queryBuilder, indexInformation); + } else { + // The only fields we've identified are inference fields (e.g. semantic_text), + // so rewrite the entire query to work on a semantic_text field. + return buildInferenceQuery(queryBuilder, indexInformation); + } + } + + /** + * @param queryBuilder {@link QueryBuilder} + * @return The singular field name requested by the provided query builder. + */ + protected abstract String getFieldName(QueryBuilder queryBuilder); + + /** + * @param queryBuilder {@link QueryBuilder} + * @return The text/query string requested by the provided query builder. + */ + protected abstract String getQuery(QueryBuilder queryBuilder); + + /** + * Builds the inference query + * + * @param queryBuilder {@link QueryBuilder} + * @param indexInformation {@link InferenceIndexInformationForField} + * @return {@link QueryBuilder} + */ + protected abstract QueryBuilder buildInferenceQuery(QueryBuilder queryBuilder, InferenceIndexInformationForField indexInformation); + + /** + * Builds a combined inference and non-inference query, + * which separates the different queries into appropriate indices based on field type. + * @param queryBuilder {@link QueryBuilder} + * @param indexInformation {@link InferenceIndexInformationForField} + * @return {@link QueryBuilder} + */ + protected abstract QueryBuilder buildCombinedInferenceAndNonInferenceQuery( + QueryBuilder queryBuilder, + InferenceIndexInformationForField indexInformation + ); + + private InferenceIndexInformationForField resolveIndicesForField(String fieldName, ResolvedIndices resolvedIndices) { + Collection indexMetadataCollection = resolvedIndices.getConcreteLocalIndicesMetadata().values(); + Map inferenceIndicesMetadata = new HashMap<>(); + List nonInferenceIndices = new ArrayList<>(); + for (IndexMetadata indexMetadata : indexMetadataCollection) { + String indexName = indexMetadata.getIndex().getName(); + InferenceFieldMetadata inferenceFieldMetadata = indexMetadata.getInferenceFields().get(fieldName); + if (inferenceFieldMetadata != null) { + inferenceIndicesMetadata.put(indexName, inferenceFieldMetadata); + } else { + nonInferenceIndices.add(indexName); + } + } + + return new InferenceIndexInformationForField(fieldName, inferenceIndicesMetadata, nonInferenceIndices); + } + + protected QueryBuilder createSubQueryForIndices(Collection indices, QueryBuilder queryBuilder) { + BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); + boolQueryBuilder.must(queryBuilder); + boolQueryBuilder.filter(new TermsQueryBuilder(IndexFieldMapper.NAME, indices)); + return boolQueryBuilder; + } + + protected QueryBuilder createSemanticSubQuery(Collection indices, String fieldName, String value) { + BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); + boolQueryBuilder.must(new SemanticQueryBuilder(fieldName, value, true)); + boolQueryBuilder.filter(new TermsQueryBuilder(IndexFieldMapper.NAME, indices)); + return boolQueryBuilder; + } + + /** + * Represents the indices and associated inference information for a field. + */ + public record InferenceIndexInformationForField( + String fieldName, + Map inferenceIndicesMetadata, + List nonInferenceIndices + ) { + + public Collection getInferenceIndices() { + return inferenceIndicesMetadata.keySet(); + } + + public Map> getInferenceIdsIndices() { + return inferenceIndicesMetadata.entrySet() + .stream() + .collect( + Collectors.groupingBy( + entry -> entry.getValue().getSearchInferenceId(), + Collectors.mapping(Map.Entry::getKey, Collectors.toList()) + ) + ); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticSparseVectorQueryRewriteInterceptor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticSparseVectorQueryRewriteInterceptor.java new file mode 100644 index 0000000000000..a35e83450c55a --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticSparseVectorQueryRewriteInterceptor.java @@ -0,0 +1,124 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.queries; + +import org.apache.lucene.search.join.ScoreMode; +import org.elasticsearch.features.NodeFeature; +import org.elasticsearch.index.query.BoolQueryBuilder; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.QueryBuilders; +import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryBuilder; +import org.elasticsearch.xpack.inference.mapper.SemanticTextField; + +import java.util.List; +import java.util.Map; + +public class SemanticSparseVectorQueryRewriteInterceptor extends SemanticQueryRewriteInterceptor { + + public static final NodeFeature SEMANTIC_SPARSE_VECTOR_QUERY_REWRITE_INTERCEPTION_SUPPORTED = new NodeFeature( + "search.semantic_sparse_vector_query_rewrite_interception_supported" + ); + + public SemanticSparseVectorQueryRewriteInterceptor() {} + + @Override + protected String getFieldName(QueryBuilder queryBuilder) { + assert (queryBuilder instanceof SparseVectorQueryBuilder); + SparseVectorQueryBuilder sparseVectorQueryBuilder = (SparseVectorQueryBuilder) queryBuilder; + return sparseVectorQueryBuilder.getFieldName(); + } + + @Override + protected String getQuery(QueryBuilder queryBuilder) { + assert (queryBuilder instanceof SparseVectorQueryBuilder); + SparseVectorQueryBuilder sparseVectorQueryBuilder = (SparseVectorQueryBuilder) queryBuilder; + return sparseVectorQueryBuilder.getQuery(); + } + + @Override + protected QueryBuilder buildInferenceQuery(QueryBuilder queryBuilder, InferenceIndexInformationForField indexInformation) { + Map> inferenceIdsIndices = indexInformation.getInferenceIdsIndices(); + if (inferenceIdsIndices.size() == 1) { + // Simple case, everything uses the same inference ID + String searchInferenceId = inferenceIdsIndices.keySet().iterator().next(); + return buildNestedQueryFromSparseVectorQuery(queryBuilder, searchInferenceId); + } else { + // Multiple inference IDs, construct a boolean query + return buildInferenceQueryWithMultipleInferenceIds(queryBuilder, inferenceIdsIndices); + } + } + + private QueryBuilder buildInferenceQueryWithMultipleInferenceIds( + QueryBuilder queryBuilder, + Map> inferenceIdsIndices + ) { + BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); + for (String inferenceId : inferenceIdsIndices.keySet()) { + boolQueryBuilder.should( + createSubQueryForIndices( + inferenceIdsIndices.get(inferenceId), + buildNestedQueryFromSparseVectorQuery(queryBuilder, inferenceId) + ) + ); + } + return boolQueryBuilder; + } + + @Override + protected QueryBuilder buildCombinedInferenceAndNonInferenceQuery( + QueryBuilder queryBuilder, + InferenceIndexInformationForField indexInformation + ) { + assert (queryBuilder instanceof SparseVectorQueryBuilder); + SparseVectorQueryBuilder sparseVectorQueryBuilder = (SparseVectorQueryBuilder) queryBuilder; + Map> inferenceIdsIndices = indexInformation.getInferenceIdsIndices(); + + BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); + boolQueryBuilder.should( + createSubQueryForIndices( + indexInformation.nonInferenceIndices(), + createSubQueryForIndices(indexInformation.nonInferenceIndices(), sparseVectorQueryBuilder) + ) + ); + // We always perform nested subqueries on semantic_text fields, to support + // sparse_vector queries using query vectors. + for (String inferenceId : inferenceIdsIndices.keySet()) { + boolQueryBuilder.should( + createSubQueryForIndices( + inferenceIdsIndices.get(inferenceId), + buildNestedQueryFromSparseVectorQuery(sparseVectorQueryBuilder, inferenceId) + ) + ); + } + return boolQueryBuilder; + } + + private QueryBuilder buildNestedQueryFromSparseVectorQuery(QueryBuilder queryBuilder, String searchInferenceId) { + assert (queryBuilder instanceof SparseVectorQueryBuilder); + SparseVectorQueryBuilder sparseVectorQueryBuilder = (SparseVectorQueryBuilder) queryBuilder; + return QueryBuilders.nestedQuery( + SemanticTextField.getChunksFieldName(sparseVectorQueryBuilder.getFieldName()), + new SparseVectorQueryBuilder( + SemanticTextField.getEmbeddingsFieldName(sparseVectorQueryBuilder.getFieldName()), + sparseVectorQueryBuilder.getQueryVectors(), + (sparseVectorQueryBuilder.getInferenceId() == null && sparseVectorQueryBuilder.getQuery() != null) + ? searchInferenceId + : sparseVectorQueryBuilder.getInferenceId(), + sparseVectorQueryBuilder.getQuery(), + sparseVectorQueryBuilder.shouldPruneTokens(), + sparseVectorQueryBuilder.getTokenPruningConfig() + ), + ScoreMode.Max + ); + } + + @Override + public String getQueryName() { + return SparseVectorQueryBuilder.NAME; + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticMatchQueryRewriteInterceptorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticMatchQueryRewriteInterceptorTests.java new file mode 100644 index 0000000000000..47705c14d5941 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticMatchQueryRewriteInterceptorTests.java @@ -0,0 +1,110 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.index.query; + +import org.elasticsearch.action.MockResolvedIndices; +import org.elasticsearch.action.OriginalIndices; +import org.elasticsearch.action.ResolvedIndices; +import org.elasticsearch.action.support.IndicesOptions; +import org.elasticsearch.cluster.metadata.IndexMetadata; +import org.elasticsearch.cluster.metadata.InferenceFieldMetadata; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.index.Index; +import org.elasticsearch.index.IndexVersion; +import org.elasticsearch.plugins.internal.rewriter.QueryRewriteInterceptor; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.client.NoOpClient; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.xpack.inference.queries.SemanticMatchQueryRewriteInterceptor; +import org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.util.Map; + +public class SemanticMatchQueryRewriteInterceptorTests extends ESTestCase { + + private TestThreadPool threadPool; + private NoOpClient client; + private Index index; + + private static final String FIELD_NAME = "fieldName"; + private static final String VALUE = "value"; + + @Before + public void setup() { + threadPool = createThreadPool(); + client = new NoOpClient(threadPool); + index = new Index(randomAlphaOfLength(10), randomAlphaOfLength(10)); + } + + @After + public void cleanup() { + threadPool.close(); + } + + public void testMatchQueryOnInferenceFieldIsInterceptedAndRewrittenToSemanticQuery() throws IOException { + Map inferenceFields = Map.of( + FIELD_NAME, + new InferenceFieldMetadata(index.getName(), "inferenceId", new String[] { FIELD_NAME }) + ); + QueryRewriteContext context = createQueryRewriteContext(inferenceFields); + QueryBuilder original = createTestQueryBuilder(); + QueryBuilder rewritten = original.rewrite(context); + assertTrue( + "Expected query to be intercepted, but was [" + rewritten.getClass().getName() + "]", + rewritten instanceof InterceptedQueryBuilderWrapper + ); + InterceptedQueryBuilderWrapper intercepted = (InterceptedQueryBuilderWrapper) rewritten; + assertTrue(intercepted.queryBuilder instanceof SemanticQueryBuilder); + SemanticQueryBuilder semanticQueryBuilder = (SemanticQueryBuilder) intercepted.queryBuilder; + assertEquals(FIELD_NAME, semanticQueryBuilder.getFieldName()); + assertEquals(VALUE, semanticQueryBuilder.getQuery()); + } + + public void testMatchQueryOnNonInferenceFieldRemainsMatchQuery() throws IOException { + QueryRewriteContext context = createQueryRewriteContext(Map.of()); // No inference fields + QueryBuilder original = createTestQueryBuilder(); + QueryBuilder rewritten = original.rewrite(context); + assertTrue( + "Expected query to remain match but was [" + rewritten.getClass().getName() + "]", + rewritten instanceof MatchQueryBuilder + ); + assertEquals(original, rewritten); + } + + private MatchQueryBuilder createTestQueryBuilder() { + return new MatchQueryBuilder(FIELD_NAME, VALUE); + } + + private QueryRewriteContext createQueryRewriteContext(Map inferenceFields) { + IndexMetadata indexMetadata = IndexMetadata.builder(index.getName()) + .settings( + Settings.builder() + .put(IndexMetadata.SETTING_VERSION_CREATED, IndexVersion.current()) + .put(IndexMetadata.SETTING_INDEX_UUID, index.getUUID()) + ) + .numberOfShards(1) + .numberOfReplicas(0) + .putInferenceFields(inferenceFields) + .build(); + + ResolvedIndices resolvedIndices = new MockResolvedIndices( + Map.of(), + new OriginalIndices(new String[] { index.getName() }, IndicesOptions.DEFAULT), + Map.of(index, indexMetadata) + ); + + return new QueryRewriteContext(null, client, null, resolvedIndices, null, createRewriteInterceptor()); + } + + private QueryRewriteInterceptor createRewriteInterceptor() { + return new SemanticMatchQueryRewriteInterceptor(); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticSparseVectorQueryRewriteInterceptorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticSparseVectorQueryRewriteInterceptorTests.java new file mode 100644 index 0000000000000..1adad1df7b29b --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticSparseVectorQueryRewriteInterceptorTests.java @@ -0,0 +1,137 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.index.query; + +import org.elasticsearch.action.MockResolvedIndices; +import org.elasticsearch.action.OriginalIndices; +import org.elasticsearch.action.ResolvedIndices; +import org.elasticsearch.action.support.IndicesOptions; +import org.elasticsearch.cluster.metadata.IndexMetadata; +import org.elasticsearch.cluster.metadata.InferenceFieldMetadata; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.index.Index; +import org.elasticsearch.index.IndexVersion; +import org.elasticsearch.plugins.internal.rewriter.QueryRewriteInterceptor; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.client.NoOpClient; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryBuilder; +import org.elasticsearch.xpack.inference.mapper.SemanticTextField; +import org.elasticsearch.xpack.inference.queries.SemanticSparseVectorQueryRewriteInterceptor; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.util.Map; + +public class SemanticSparseVectorQueryRewriteInterceptorTests extends ESTestCase { + + private TestThreadPool threadPool; + private NoOpClient client; + private Index index; + + private static final String FIELD_NAME = "fieldName"; + private static final String INFERENCE_ID = "inferenceId"; + private static final String QUERY = "query"; + + @Before + public void setup() { + threadPool = createThreadPool(); + client = new NoOpClient(threadPool); + index = new Index(randomAlphaOfLength(10), randomAlphaOfLength(10)); + } + + @After + public void cleanup() { + threadPool.close(); + } + + public void testSparseVectorQueryOnInferenceFieldIsInterceptedAndRewritten() throws IOException { + Map inferenceFields = Map.of( + FIELD_NAME, + new InferenceFieldMetadata(index.getName(), "inferenceId", new String[] { FIELD_NAME }) + ); + QueryRewriteContext context = createQueryRewriteContext(inferenceFields); + QueryBuilder original = new SparseVectorQueryBuilder(FIELD_NAME, INFERENCE_ID, QUERY); + QueryBuilder rewritten = original.rewrite(context); + assertTrue( + "Expected query to be intercepted, but was [" + rewritten.getClass().getName() + "]", + rewritten instanceof InterceptedQueryBuilderWrapper + ); + InterceptedQueryBuilderWrapper intercepted = (InterceptedQueryBuilderWrapper) rewritten; + assertTrue(intercepted.queryBuilder instanceof NestedQueryBuilder); + NestedQueryBuilder nestedQueryBuilder = (NestedQueryBuilder) intercepted.queryBuilder; + assertEquals(SemanticTextField.getChunksFieldName(FIELD_NAME), nestedQueryBuilder.path()); + QueryBuilder innerQuery = nestedQueryBuilder.query(); + assertTrue(innerQuery instanceof SparseVectorQueryBuilder); + SparseVectorQueryBuilder sparseVectorQueryBuilder = (SparseVectorQueryBuilder) innerQuery; + assertEquals(SemanticTextField.getEmbeddingsFieldName(FIELD_NAME), sparseVectorQueryBuilder.getFieldName()); + assertEquals(INFERENCE_ID, sparseVectorQueryBuilder.getInferenceId()); + assertEquals(QUERY, sparseVectorQueryBuilder.getQuery()); + } + + public void testSparseVectorQueryOnInferenceFieldWithoutInferenceIdIsInterceptedAndRewritten() throws IOException { + Map inferenceFields = Map.of( + FIELD_NAME, + new InferenceFieldMetadata(index.getName(), "inferenceId", new String[] { FIELD_NAME }) + ); + QueryRewriteContext context = createQueryRewriteContext(inferenceFields); + QueryBuilder original = new SparseVectorQueryBuilder(FIELD_NAME, null, QUERY); + QueryBuilder rewritten = original.rewrite(context); + assertTrue( + "Expected query to be intercepted, but was [" + rewritten.getClass().getName() + "]", + rewritten instanceof InterceptedQueryBuilderWrapper + ); + InterceptedQueryBuilderWrapper intercepted = (InterceptedQueryBuilderWrapper) rewritten; + assertTrue(intercepted.queryBuilder instanceof NestedQueryBuilder); + NestedQueryBuilder nestedQueryBuilder = (NestedQueryBuilder) intercepted.queryBuilder; + assertEquals(SemanticTextField.getChunksFieldName(FIELD_NAME), nestedQueryBuilder.path()); + QueryBuilder innerQuery = nestedQueryBuilder.query(); + assertTrue(innerQuery instanceof SparseVectorQueryBuilder); + SparseVectorQueryBuilder sparseVectorQueryBuilder = (SparseVectorQueryBuilder) innerQuery; + assertEquals(SemanticTextField.getEmbeddingsFieldName(FIELD_NAME), sparseVectorQueryBuilder.getFieldName()); + assertEquals(INFERENCE_ID, sparseVectorQueryBuilder.getInferenceId()); + assertEquals(QUERY, sparseVectorQueryBuilder.getQuery()); + } + + public void testSparseVectorQueryOnNonInferenceFieldRemainsUnchanged() throws IOException { + QueryRewriteContext context = createQueryRewriteContext(Map.of()); // No inference fields + QueryBuilder original = new SparseVectorQueryBuilder(FIELD_NAME, INFERENCE_ID, QUERY); + QueryBuilder rewritten = original.rewrite(context); + assertTrue( + "Expected query to remain sparse_vector but was [" + rewritten.getClass().getName() + "]", + rewritten instanceof SparseVectorQueryBuilder + ); + assertEquals(original, rewritten); + } + + private QueryRewriteContext createQueryRewriteContext(Map inferenceFields) { + IndexMetadata indexMetadata = IndexMetadata.builder(index.getName()) + .settings( + Settings.builder() + .put(IndexMetadata.SETTING_VERSION_CREATED, IndexVersion.current()) + .put(IndexMetadata.SETTING_INDEX_UUID, index.getUUID()) + ) + .numberOfShards(1) + .numberOfReplicas(0) + .putInferenceFields(inferenceFields) + .build(); + + ResolvedIndices resolvedIndices = new MockResolvedIndices( + Map.of(), + new OriginalIndices(new String[] { index.getName() }, IndicesOptions.DEFAULT), + Map.of(index, indexMetadata) + ); + + return new QueryRewriteContext(null, client, null, resolvedIndices, null, createRewriteInterceptor()); + } + + private QueryRewriteInterceptor createRewriteInterceptor() { + return new SemanticSparseVectorQueryRewriteInterceptor(); + } +} diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/46_semantic_text_sparse_vector.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/46_semantic_text_sparse_vector.yml new file mode 100644 index 0000000000000..f1cff512fd209 --- /dev/null +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/46_semantic_text_sparse_vector.yml @@ -0,0 +1,249 @@ +setup: + - requires: + cluster_features: "search.semantic_sparse_vector_query_rewrite_interception_supported" + reason: semantic_text sparse_vector support introduced in 8.18.0 + + - do: + inference.put: + task_type: sparse_embedding + inference_id: sparse-inference-id + body: > + { + "service": "test_service", + "service_settings": { + "model": "my_model", + "api_key": "abc64" + }, + "task_settings": { + } + } + + - do: + inference.put: + task_type: sparse_embedding + inference_id: sparse-inference-id-2 + body: > + { + "service": "test_service", + "service_settings": { + "model": "my_model", + "api_key": "abc64" + }, + "task_settings": { + } + } + + - do: + indices.create: + index: test-semantic-text-index + body: + mappings: + properties: + inference_field: + type: semantic_text + inference_id: sparse-inference-id + + - do: + indices.create: + index: test-semantic-text-index-2 + body: + mappings: + properties: + inference_field: + type: semantic_text + inference_id: sparse-inference-id-2 + + - do: + indices.create: + index: test-sparse-vector-index + body: + mappings: + properties: + inference_field: + type: sparse_vector + + - do: + index: + index: test-semantic-text-index + id: doc_1 + body: + inference_field: [ "inference test", "another inference test" ] + refresh: true + + - do: + index: + index: test-semantic-text-index-2 + id: doc_3 + body: + inference_field: [ "inference test", "another inference test" ] + refresh: true + + - do: + index: + index: test-sparse-vector-index + id: doc_2 + body: + inference_field: { "feature_0": 1, "feature_1": 2, "feature_2": 3, "feature_3": 4, "feature_4": 5 } + refresh: true + +--- +"Nested sparse_vector queries using the old format on semantic_text embeddings and inference still work": + - skip: + features: [ "headers" ] + + - do: + headers: + # Force JSON content type so that we use a parser that interprets the floating-point score as a double + Content-Type: application/json + search: + index: test-semantic-text-index + body: + query: + nested: + path: inference_field.inference.chunks + query: + sparse_vector: + field: inference_field.inference.chunks.embeddings + inference_id: sparse-inference-id + query: test + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } + +--- +"Nested sparse_vector queries using the old format on semantic_text embeddings and query vectors still work": + - skip: + features: [ "headers" ] + + - do: + headers: + # Force JSON content type so that we use a parser that interprets the floating-point score as a double + Content-Type: application/json + search: + index: test-semantic-text-index + body: + query: + nested: + path: inference_field.inference.chunks + query: + sparse_vector: + field: inference_field.inference.chunks.embeddings + query_vector: { "feature_0": 1, "feature_1": 2, "feature_2": 3, "feature_3": 4, "feature_4": 5 } + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } + +--- +"sparse_vector query against semantic_text field using a specified inference ID": + + - do: + search: + index: test-semantic-text-index + body: + query: + sparse_vector: + field: inference_field + inference_id: sparse-inference-id + query: "inference test" + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } + +--- +"sparse_vector query against semantic_text field using inference ID configured in semantic_text field": + + - do: + search: + index: test-semantic-text-index + body: + query: + sparse_vector: + field: inference_field + query: "inference test" + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } + +--- +"sparse_vector query against semantic_text field using query vectors": + + - do: + search: + index: test-semantic-text-index + body: + query: + sparse_vector: + field: inference_field + query_vector: { "feature_0": 1, "feature_1": 2, "feature_2": 3, "feature_3": 4, "feature_4": 5 } + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } + +--- +"sparse_vector query against combined sparse_vector and semantic_text fields using inference": + + - do: + search: + index: + - test-semantic-text-index + - test-sparse-vector-index + body: + query: + sparse_vector: + field: inference_field + inference_id: sparse-inference-id + query: "inference test" + + - match: { hits.total.value: 2 } + +--- +"sparse_vector query against combined sparse_vector and semantic_text fields still requires inference ID": + + - do: + catch: bad_request + search: + index: + - test-semantic-text-index + - test-sparse-vector-index + body: + query: + sparse_vector: + field: inference_field + query: "inference test" + + - match: { error.type: "illegal_argument_exception" } + - match: { error.reason: "inference_id required to perform vector search on query string" } + +--- +"sparse_vector query against combined sparse_vector and semantic_text fields using query vectors": + + - do: + search: + index: + - test-semantic-text-index + - test-sparse-vector-index + body: + query: + sparse_vector: + field: inference_field + query_vector: { "feature_0": 1, "feature_1": 2, "feature_2": 3, "feature_3": 4, "feature_4": 5 } + + - match: { hits.total.value: 2 } + + +--- +"sparse_vector query against multiple semantic_text fields with multiple inference IDs specified in semantic_text fields": + + - do: + search: + index: + - test-semantic-text-index + - test-semantic-text-index-2 + body: + query: + sparse_vector: + field: inference_field + query: "inference test" + + - match: { hits.total.value: 2 } + diff --git a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/sparse_vector_search.yml b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/sparse_vector_search.yml index 332981a580802..3481773b0bab3 100644 --- a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/sparse_vector_search.yml +++ b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/sparse_vector_search.yml @@ -268,7 +268,7 @@ setup: - match: { hits.hits.0._score: 0.25 } --- -"Test sparse_vector requires one of inference_id or query_vector": +"Test sparse_vector requires one of query or query_vector": - do: catch: /\[sparse_vector\] requires one of \[query_vector\] or \[inference_id\]/ search: @@ -281,7 +281,41 @@ setup: - match: { status: 400 } --- -"Test sparse_vector only allows one of inference_id or query_vector": +"Test sparse_vector returns an error if inference ID not specified with query": + - do: + catch: bad_request # This is for BWC, the actual error message is tested in a subsequent test + search: + index: index-with-sparse-vector + body: + query: + sparse_vector: + field: text + query: "octopus comforter smells" + + - match: { status: 400 } + +--- +"Test sparse_vector requires an inference ID to be specified on sparse_vector fields": + - requires: + cluster_features: [ "search.semantic_sparse_vector_query_rewrite_interception_supported" ] + reason: "Error message changed in 8.18" + - do: + catch: /inference_id required to perform vector search on query string/ + search: + index: index-with-sparse-vector + body: + query: + sparse_vector: + field: text + query: "octopus comforter smells" + + - match: { status: 400 } + +--- +"Test sparse_vector only allows one of query or query_vector (note the error message is misleading)": + - requires: + cluster_features: [ "search.semantic_sparse_vector_query_rewrite_interception_supported" ] + reason: "sparse vector inference checks updated in 8.18 to support sparse_vector on semantic_text fields" - do: catch: /\[sparse_vector\] requires one of \[query_vector\] or \[inference_id\]/ search: @@ -290,7 +324,7 @@ setup: query: sparse_vector: field: text - inference_id: text_expansion_model + query: "octopus comforter smells" query_vector: the: 1.0 comforter: 1.0