Skip to content

Commit

Permalink
Add support for sparse_vector queries against semantic_text fields (e…
Browse files Browse the repository at this point in the history
…lastic#118617) (elastic#118951)

(cherry picked from commit 15bec3c)

# Conflicts:
#	x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java
  • Loading branch information
kderusso authored Dec 18, 2024
1 parent 226ff67 commit 19fd296
Show file tree
Hide file tree
Showing 13 changed files with 887 additions and 76 deletions.
5 changes: 5 additions & 0 deletions docs/changelog/118617.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 118617
summary: Add support for `sparse_vector` queries against `semantic_text` fields
area: "Search"
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -90,26 +90,33 @@ public SparseVectorQueryBuilder(
: (this.shouldPruneTokens ? new TokenPruningConfig() : null));
this.weightedTokensSupplier = null;

if (queryVectors == null ^ inferenceId == null == false) {
// Preserve BWC error messaging
if (queryVectors != null && inferenceId != null) {
throw new IllegalArgumentException(
"["
+ NAME
+ "] requires one of ["
+ QUERY_VECTOR_FIELD.getPreferredName()
+ "] or ["
+ INFERENCE_ID_FIELD.getPreferredName()
+ "]"
+ "] for "
+ ALLOWED_FIELD_TYPE
+ " fields"
);
}
if (inferenceId != null && query == null) {

// Preserve BWC error messaging
if ((queryVectors == null) == (query == null)) {
throw new IllegalArgumentException(
"["
+ NAME
+ "] requires ["
+ QUERY_FIELD.getPreferredName()
+ "] when ["
+ "] requires one of ["
+ QUERY_VECTOR_FIELD.getPreferredName()
+ "] or ["
+ INFERENCE_ID_FIELD.getPreferredName()
+ "] is specified"
+ "] for "
+ ALLOWED_FIELD_TYPE
+ " fields"
);
}
}
Expand Down Expand Up @@ -143,6 +150,14 @@ public List<WeightedToken> getQueryVectors() {
return queryVectors;
}

public String getInferenceId() {
return inferenceId;
}

public String getQuery() {
return query;
}

public boolean shouldPruneTokens() {
return shouldPruneTokens;
}
Expand Down Expand Up @@ -176,7 +191,9 @@ protected void doXContent(XContentBuilder builder, Params params) throws IOExcep
}
builder.endObject();
} else {
builder.field(INFERENCE_ID_FIELD.getPreferredName(), inferenceId);
if (inferenceId != null) {
builder.field(INFERENCE_ID_FIELD.getPreferredName(), inferenceId);
}
builder.field(QUERY_FIELD.getPreferredName(), query);
}
builder.field(PRUNE_FIELD.getPreferredName(), shouldPruneTokens);
Expand Down Expand Up @@ -228,6 +245,11 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) {
shouldPruneTokens,
tokenPruningConfig
);
} else if (inferenceId == null) {
// Edge case, where inference_id was not specified in the request,
// but we did not intercept this and rewrite to a query o field with
// pre-configured inference. So we trap here and output a nicer error message.
throw new IllegalArgumentException("inference_id required to perform vector search on query string");
}

// TODO move this to xpack core and use inference APIs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -260,16 +260,16 @@ public void testIllegalValues() {
{
IllegalArgumentException e = expectThrows(
IllegalArgumentException.class,
() -> new SparseVectorQueryBuilder("field name", null, "model id")
() -> new SparseVectorQueryBuilder("field name", null, null)
);
assertEquals("[sparse_vector] requires one of [query_vector] or [inference_id]", e.getMessage());
assertEquals("[sparse_vector] requires one of [query_vector] or [inference_id] for sparse_vector fields", e.getMessage());
}
{
IllegalArgumentException e = expectThrows(
IllegalArgumentException.class,
() -> new SparseVectorQueryBuilder("field name", "model text", null)
);
assertEquals("[sparse_vector] requires [query] when [inference_id] is specified", e.getMessage());
assertEquals("[sparse_vector] requires one of [query_vector] or [inference_id] for sparse_vector fields", e.getMessage());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,14 @@
import org.elasticsearch.features.FeatureSpecification;
import org.elasticsearch.features.NodeFeature;
import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;
import org.elasticsearch.xpack.inference.queries.SemanticMatchQueryRewriteInterceptor;
import org.elasticsearch.xpack.inference.rank.random.RandomRankRetrieverBuilder;
import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder;

import java.util.Set;

import static org.elasticsearch.xpack.inference.queries.SemanticMatchQueryRewriteInterceptor.SEMANTIC_MATCH_QUERY_REWRITE_INTERCEPTION_SUPPORTED;
import static org.elasticsearch.xpack.inference.queries.SemanticSparseVectorQueryRewriteInterceptor.SEMANTIC_SPARSE_VECTOR_QUERY_REWRITE_INTERCEPTION_SUPPORTED;

/**
* Provides inference features.
*/
Expand Down Expand Up @@ -43,7 +45,8 @@ public Set<NodeFeature> getTestFeatures() {
SemanticTextFieldMapper.SEMANTIC_TEXT_ZERO_SIZE_FIX,
SemanticTextFieldMapper.SEMANTIC_TEXT_ALWAYS_EMIT_INFERENCE_ID_FIX,
SEMANTIC_TEXT_HIGHLIGHTER,
SemanticMatchQueryRewriteInterceptor.SEMANTIC_MATCH_QUERY_REWRITE_INTERCEPTION_SUPPORTED
SEMANTIC_MATCH_QUERY_REWRITE_INTERCEPTION_SUPPORTED,
SEMANTIC_SPARSE_VECTOR_QUERY_REWRITE_INTERCEPTION_SUPPORTED
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;
import org.elasticsearch.xpack.inference.queries.SemanticMatchQueryRewriteInterceptor;
import org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder;
import org.elasticsearch.xpack.inference.queries.SemanticSparseVectorQueryRewriteInterceptor;
import org.elasticsearch.xpack.inference.rank.random.RandomRankBuilder;
import org.elasticsearch.xpack.inference.rank.random.RandomRankRetrieverBuilder;
import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankBuilder;
Expand Down Expand Up @@ -404,7 +405,7 @@ public List<QuerySpec<?>> getQueries() {

@Override
public List<QueryRewriteInterceptor> getQueryRewriteInterceptors() {
return List.of(new SemanticMatchQueryRewriteInterceptor());
return List.of(new SemanticMatchQueryRewriteInterceptor(), new SemanticSparseVectorQueryRewriteInterceptor());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,24 +7,12 @@

package org.elasticsearch.xpack.inference.queries;

import org.elasticsearch.action.ResolvedIndices;
import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;
import org.elasticsearch.features.NodeFeature;
import org.elasticsearch.index.mapper.IndexFieldMapper;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.MatchQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.index.query.TermQueryBuilder;
import org.elasticsearch.index.query.TermsQueryBuilder;
import org.elasticsearch.plugins.internal.rewriter.QueryRewriteInterceptor;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;

public class SemanticMatchQueryRewriteInterceptor implements QueryRewriteInterceptor {
public class SemanticMatchQueryRewriteInterceptor extends SemanticQueryRewriteInterceptor {

public static final NodeFeature SEMANTIC_MATCH_QUERY_REWRITE_INTERCEPTION_SUPPORTED = new NodeFeature(
"search.semantic_match_query_rewrite_interception_supported"
Expand All @@ -33,63 +21,45 @@ public class SemanticMatchQueryRewriteInterceptor implements QueryRewriteInterce
public SemanticMatchQueryRewriteInterceptor() {}

@Override
public QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilder queryBuilder) {
protected String getFieldName(QueryBuilder queryBuilder) {
assert (queryBuilder instanceof MatchQueryBuilder);
MatchQueryBuilder matchQueryBuilder = (MatchQueryBuilder) queryBuilder;
QueryBuilder rewritten = queryBuilder;
ResolvedIndices resolvedIndices = context.getResolvedIndices();
if (resolvedIndices != null) {
Collection<IndexMetadata> indexMetadataCollection = resolvedIndices.getConcreteLocalIndicesMetadata().values();
List<String> inferenceIndices = new ArrayList<>();
List<String> nonInferenceIndices = new ArrayList<>();
for (IndexMetadata indexMetadata : indexMetadataCollection) {
String indexName = indexMetadata.getIndex().getName();
InferenceFieldMetadata inferenceFieldMetadata = indexMetadata.getInferenceFields().get(matchQueryBuilder.fieldName());
if (inferenceFieldMetadata != null) {
inferenceIndices.add(indexName);
} else {
nonInferenceIndices.add(indexName);
}
}

if (inferenceIndices.isEmpty()) {
return rewritten;
} else if (nonInferenceIndices.isEmpty() == false) {
BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
for (String inferenceIndexName : inferenceIndices) {
// Add a separate clause for each semantic query, because they may be using different inference endpoints
// TODO - consolidate this to a single clause once the semantic query supports multiple inference endpoints
boolQueryBuilder.should(
createSemanticSubQuery(inferenceIndexName, matchQueryBuilder.fieldName(), (String) matchQueryBuilder.value())
);
}
boolQueryBuilder.should(createMatchSubQuery(nonInferenceIndices, matchQueryBuilder));
rewritten = boolQueryBuilder;
} else {
rewritten = new SemanticQueryBuilder(matchQueryBuilder.fieldName(), (String) matchQueryBuilder.value(), false);
}
}

return rewritten;
return matchQueryBuilder.fieldName();
}

@Override
protected String getQuery(QueryBuilder queryBuilder) {
assert (queryBuilder instanceof MatchQueryBuilder);
MatchQueryBuilder matchQueryBuilder = (MatchQueryBuilder) queryBuilder;
return (String) matchQueryBuilder.value();
}

@Override
public String getQueryName() {
return MatchQueryBuilder.NAME;
protected QueryBuilder buildInferenceQuery(QueryBuilder queryBuilder, InferenceIndexInformationForField indexInformation) {
return new SemanticQueryBuilder(indexInformation.fieldName(), getQuery(queryBuilder), false);
}

private QueryBuilder createSemanticSubQuery(String indexName, String fieldName, String value) {
@Override
protected QueryBuilder buildCombinedInferenceAndNonInferenceQuery(
QueryBuilder queryBuilder,
InferenceIndexInformationForField indexInformation
) {
assert (queryBuilder instanceof MatchQueryBuilder);
MatchQueryBuilder matchQueryBuilder = (MatchQueryBuilder) queryBuilder;
BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
boolQueryBuilder.must(new SemanticQueryBuilder(fieldName, value, true));
boolQueryBuilder.filter(new TermQueryBuilder(IndexFieldMapper.NAME, indexName));
boolQueryBuilder.should(
createSemanticSubQuery(
indexInformation.getInferenceIndices(),
matchQueryBuilder.fieldName(),
(String) matchQueryBuilder.value()
)
);
boolQueryBuilder.should(createSubQueryForIndices(indexInformation.nonInferenceIndices(), matchQueryBuilder));
return boolQueryBuilder;
}

private QueryBuilder createMatchSubQuery(List<String> indices, MatchQueryBuilder matchQueryBuilder) {
BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
boolQueryBuilder.must(matchQueryBuilder);
boolQueryBuilder.filter(new TermsQueryBuilder(IndexFieldMapper.NAME, indices));
return boolQueryBuilder;
@Override
public String getQueryName() {
return MatchQueryBuilder.NAME;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,14 @@ public String getWriteableName() {
return NAME;
}

public String getFieldName() {
return fieldName;
}

public String getQuery() {
return query;
}

@Override
public TransportVersion getMinimalSupportedVersion() {
return TransportVersions.V_8_15_0;
Expand Down
Loading

0 comments on commit 19fd296

Please sign in to comment.