Skip to content

Commit

Permalink
Add a rule query retriever reusing the compound builder
Browse files Browse the repository at this point in the history
  • Loading branch information
jimczi committed Oct 16, 2024
1 parent bd23441 commit 3766d7c
Show file tree
Hide file tree
Showing 4 changed files with 180 additions and 57 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.search.TransportMultiSearchAction;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.search.builder.PointInTimeBuilder;
Expand Down Expand Up @@ -163,6 +162,11 @@ public final QueryBuilder topDocsQuery() {
throw new IllegalStateException("Should not be called, missing a rewrite?");
}

/**
 * Unconditionally rejects the call: this implementation never produces an explain query.
 * The exception message suggests the builder is expected to have been rewritten into
 * another form before this method is reached (the rewrite itself is not visible here).
 *
 * @throws IllegalStateException on every invocation
 */
@Override
public final QueryBuilder explainQuery() {
throw new IllegalStateException("Should not be called, missing a rewrite?");
}

@Override
public final void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder, boolean compoundUsed) {
throw new IllegalStateException("Should not be called, missing a rewrite?");
Expand Down Expand Up @@ -218,30 +222,14 @@ protected SearchSourceBuilder createSearchSourceBuilder(PointInTimeBuilder pit,
}
retrieverBuilder.extractToSearchSourceBuilder(sourceBuilder, true);

// apply the pre-filters
if (preFilterQueryBuilders.size() > 0) {
QueryBuilder query = sourceBuilder.query();
BoolQueryBuilder newQuery = new BoolQueryBuilder();
if (query != null) {
newQuery.must(query);
}
preFilterQueryBuilders.forEach(newQuery::filter);
sourceBuilder.query(newQuery);
}

addSort(sourceBuilder);

return sourceBuilder;
}

protected void addSort(SearchSourceBuilder sourceBuilder) {
// Record the shard id in the sort result
List<SortBuilder<?>> sortBuilders = sourceBuilder.sorts() != null ? new ArrayList<>(sourceBuilder.sorts()) : new ArrayList<>();
if (sortBuilders.isEmpty()) {
sortBuilders.add(new ScoreSortBuilder());
}
sortBuilders.add(new FieldSortBuilder(FieldSortBuilder.SHARD_DOC_FIELD_NAME));
sourceBuilder.sort(sortBuilders);
return sourceBuilder;
}

private RankDoc[] getRankDocs(SearchResponse searchResponse) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.search.retriever;

import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.rank.RankDoc;
import org.elasticsearch.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.List;

/**
 * A wrapper that can be used to modify the behaviour of an existing {@link RetrieverBuilder}.
 * <p>
 * Every overridden method forwards to the wrapped builder {@link #in}; subclasses override only
 * the methods whose behaviour they want to change.
 *
 * @param <T> the concrete wrapper type produced by {@link #clone(RetrieverBuilder)}
 */
public abstract class RetrieverBuilderWrapper<T extends RetrieverBuilder> extends RetrieverBuilder {
    /** The wrapped builder that all calls are forwarded to. */
    protected final RetrieverBuilder in;

    /**
     * @param in the builder to wrap
     */
    protected RetrieverBuilderWrapper(RetrieverBuilder in) {
        this.in = in;
    }

    /**
     * Creates a new wrapper of the same kind around a different builder.
     * Used by {@link #rewrite(QueryRewriteContext)} when the wrapped builder rewrites to a new instance.
     *
     * @param sub the builder the new wrapper should delegate to
     * @return a wrapper around {@code sub}
     */
    protected abstract T clone(RetrieverBuilder sub);

    /**
     * Rewrites the wrapped builder and re-wraps the result if it changed.
     *
     * @return {@code this} when the wrapped builder rewrote to itself, otherwise a new wrapper
     *         (via {@link #clone(RetrieverBuilder)}) around the rewritten builder
     */
    @Override
    public RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOException {
        var inRewrite = in.rewrite(ctx);
        if (inRewrite != in) {
            return clone(inRewrite);
        }
        return this;
    }

    @Override
    public QueryBuilder topDocsQuery() {
        return in.topDocsQuery();
    }

    /**
     * Forwards to the wrapped builder and returns whatever it returns,
     * which is not necessarily this wrapper.
     */
    @Override
    public RetrieverBuilder minScore(Float minScore) {
        return in.minScore(minScore);
    }

    @Override
    public List<QueryBuilder> getPreFilterQueryBuilders() {
        // Go through the getter rather than the field so that a wrapped builder that overrides
        // the getter (for instance another wrapper) is honoured, like every other delegate here.
        return in.getPreFilterQueryBuilders();
    }

    @Override
    public ActionRequestValidationException validate(
        SearchSourceBuilder source,
        ActionRequestValidationException validationException,
        boolean allowPartialSearchResults
    ) {
        return in.validate(source, validationException, allowPartialSearchResults);
    }

    /**
     * Forwards to the wrapped builder and returns whatever it returns,
     * which is not necessarily this wrapper.
     */
    @Override
    public RetrieverBuilder retrieverName(String retrieverName) {
        return in.retrieverName(retrieverName);
    }

    @Override
    public void setRankDocs(RankDoc[] rankDocs) {
        in.setRankDocs(rankDocs);
    }

    @Override
    public boolean isCompound() {
        return in.isCompound();
    }

    @Override
    public QueryBuilder explainQuery() {
        return in.explainQuery();
    }

    @Override
    public Float minScore() {
        return in.minScore();
    }

    @Override
    public boolean isFragment() {
        return in.isFragment();
    }

    @Override
    public String toString() {
        return in.toString();
    }

    @Override
    public String retrieverName() {
        return in.retrieverName();
    }

    @Override
    public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder, boolean compoundUsed) {
        in.extractToSearchSourceBuilder(searchSourceBuilder, compoundUsed);
    }

    @Override
    public String getName() {
        return in.getName();
    }

    @Override
    protected void doToXContent(XContentBuilder builder, Params params) throws IOException {
        in.doToXContent(builder, params);
    }

    /**
     * Compares the wrapped builders rather than the wrappers themselves.
     *
     * @param o the object to compare against; may be another wrapper or a plain builder
     * @return {@code true} if the wrapped builder equals the builder behind {@code o}
     */
    @Override
    protected boolean doEquals(Object o) {
        // When reached through equals(), o is presumably another wrapper (the base-class equals()
        // is not visible here). Comparing the wrapped builder against the wrapper itself would
        // never match, so unwrap the other side first.
        if (o instanceof RetrieverBuilderWrapper<?>) {
            return in.equals(((RetrieverBuilderWrapper<?>) o).in);
        }
        return in.equals(o);
    }

    @Override
    protected int doHashCode() {
        return in.doHashCode();
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@
import org.elasticsearch.license.LicenseUtils;
import org.elasticsearch.search.builder.PointInTimeBuilder;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.fetch.StoredFieldsContext;
import org.elasticsearch.search.rank.RankDoc;
import org.elasticsearch.search.retriever.CompoundRetrieverBuilder;
import org.elasticsearch.search.retriever.RetrieverBuilder;
import org.elasticsearch.search.retriever.RetrieverBuilderWrapper;
import org.elasticsearch.search.retriever.RetrieverParserContext;
import org.elasticsearch.search.retriever.rankdoc.RankDocsQueryBuilder;
import org.elasticsearch.search.sort.ScoreSortBuilder;
import org.elasticsearch.search.sort.SortBuilder;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentBuilder;
Expand All @@ -29,6 +31,7 @@
import org.elasticsearch.xpack.core.XPackPlugin;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
Expand Down Expand Up @@ -93,9 +96,10 @@ public QueryRuleRetrieverBuilder(
RetrieverBuilder retrieverBuilder,
int rankWindowSize
) {
super(List.of(new RetrieverSource(retrieverBuilder, null)), rankWindowSize);
super(new ArrayList<>(), rankWindowSize);
this.rulesetIds = rulesetIds;
this.matchCriteria = matchCriteria;
addChild(new QueryRuleRetrieverBuilderWrapper(retrieverBuilder));
}

public QueryRuleRetrieverBuilder(
Expand All @@ -118,21 +122,20 @@ public String getName() {

@Override
protected SearchSourceBuilder createSearchSourceBuilder(PointInTimeBuilder pit, RetrieverBuilder retrieverBuilder) {
var sourceBuilder = new SearchSourceBuilder().pointInTimeBuilder(pit)
.trackTotalHits(false)
.storedFields(new StoredFieldsContext(false))
.size(rankWindowSize);
retrieverBuilder.extractToSearchSourceBuilder(sourceBuilder, true);

QueryBuilder query = sourceBuilder.query();
if (query != null && query instanceof RuleQueryBuilder == false) {
QueryBuilder ruleQuery = new RuleQueryBuilder(query, matchCriteria, rulesetIds);
sourceBuilder.query(ruleQuery);
}
var ret = super.createSearchSourceBuilder(pit, retrieverBuilder);
checkValidSort(ret.sorts());
ret.query(new RuleQueryBuilder(ret.query(), matchCriteria, rulesetIds));
return ret;
}

addSort(sourceBuilder);
private static void checkValidSort(List<SortBuilder<?>> sortBuilders) {
if (sortBuilders.isEmpty()) {
return;
}

return sourceBuilder;
if (sortBuilders.size() > 1 || sortBuilders.get(0) instanceof ScoreSortBuilder == false) {
throw new IllegalArgumentException("Rule retrievers can only sort documents by relevance score, got: " + sortBuilders);
}
}

@Override
Expand Down Expand Up @@ -160,12 +163,6 @@ protected RankDoc[] combineInnerRetrieverResults(List<ScoreDoc[]> rankResults) {
return rankDocs;
}

@Override
public QueryBuilder explainQuery() {
// the original matching set of the QueryRuleRetriever retriever is specified by its nested retriever
return new RankDocsQueryBuilder(rankDocs, new QueryBuilder[] { innerRetrievers.getFirst().retriever().explainQuery() }, true);
}

@Override
public boolean doEquals(Object o) {
QueryRuleRetrieverBuilder that = (QueryRuleRetrieverBuilder) o;
Expand All @@ -176,4 +173,29 @@ public boolean doEquals(Object o) {
// Folds the superclass hash together with the ruleset ids and the match criteria.
public int doHashCode() {
    final int inheritedHash = super.doHashCode();
    return Objects.hash(inheritedHash, rulesetIds, matchCriteria);
}

/**
 * Wrapper around the nested retriever that decorates its top-docs and explain queries with a
 * {@link RuleQueryBuilder} built from the enclosing builder's match criteria and ruleset ids.
 * All other calls are forwarded unchanged by {@link RetrieverBuilderWrapper}.
 */
class QueryRuleRetrieverBuilderWrapper extends RetrieverBuilderWrapper<QueryRuleRetrieverBuilderWrapper> {
    protected QueryRuleRetrieverBuilderWrapper(RetrieverBuilder delegate) {
        super(delegate);
    }

    /** Re-wraps a (typically rewritten) builder in a fresh wrapper of this type. */
    @Override
    protected QueryRuleRetrieverBuilderWrapper clone(RetrieverBuilder delegate) {
        return new QueryRuleRetrieverBuilderWrapper(delegate);
    }

    /** Returns the wrapped retriever's top-docs query, wrapped in a rule query. */
    @Override
    public QueryBuilder topDocsQuery() {
        QueryBuilder wrappedTopDocs = in.topDocsQuery();
        return new RuleQueryBuilder(wrappedTopDocs, matchCriteria, rulesetIds);
    }

    /**
     * Returns a rank-docs query whose single source query is the wrapped retriever's explain
     * query, itself wrapped in a rule query.
     */
    @Override
    public QueryBuilder explainQuery() {
        QueryBuilder wrappedExplain = in.explainQuery();
        QueryBuilder ruleExplain = new RuleQueryBuilder(wrappedExplain, matchCriteria, rulesetIds);
        // NOTE(review): `rankDocs` is unqualified. If this wrapper inherits a member with that
        // name, Java resolves to the inherited one before the enclosing builder's - confirm
        // which of the two is intended here.
        return new RankDocsQueryBuilder(rankDocs, new QueryBuilder[] { ruleExplain }, true);
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import org.apache.lucene.search.ScoreDoc;
import org.elasticsearch.common.ParsingException;
import org.elasticsearch.features.NodeFeature;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.license.LicenseUtils;
import org.elasticsearch.search.builder.PointInTimeBuilder;
Expand All @@ -20,7 +19,6 @@
import org.elasticsearch.search.retriever.CompoundRetrieverBuilder;
import org.elasticsearch.search.retriever.RetrieverBuilder;
import org.elasticsearch.search.retriever.RetrieverParserContext;
import org.elasticsearch.search.retriever.rankdoc.RankDocsQueryBuilder;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentBuilder;
Expand Down Expand Up @@ -158,12 +156,6 @@ protected RankDoc[] combineInnerRetrieverResults(List<ScoreDoc[]> rankResults) {
return textSimilarityRankDocs;
}

@Override
public QueryBuilder explainQuery() {
// the original matching set of the TextSimilarityRank retriever is specified by its nested retriever
return new RankDocsQueryBuilder(rankDocs, new QueryBuilder[] { innerRetrievers.getFirst().retriever().explainQuery() }, true);
}

@Override
protected SearchSourceBuilder createSearchSourceBuilder(PointInTimeBuilder pit, RetrieverBuilder retrieverBuilder) {
var sourceBuilder = new SearchSourceBuilder().pointInTimeBuilder(pit)
Expand All @@ -175,16 +167,6 @@ protected SearchSourceBuilder createSearchSourceBuilder(PointInTimeBuilder pit,
}
retrieverBuilder.extractToSearchSourceBuilder(sourceBuilder, true);

// apply the pre-filters
if (preFilterQueryBuilders.size() > 0) {
QueryBuilder query = sourceBuilder.query();
BoolQueryBuilder newQuery = new BoolQueryBuilder();
if (query != null) {
newQuery.must(query);
}
preFilterQueryBuilders.forEach(newQuery::filter);
sourceBuilder.query(newQuery);
}
sourceBuilder.rankBuilder(
new TextSimilarityRankBuilder(this.field, this.inferenceId, this.inferenceText, this.rankWindowSize, this.minScore)
);
Expand Down

0 comments on commit 3766d7c

Please sign in to comment.