Skip to content

Commit

Permalink
Add prefilters only once in the compound and text similarity retrievers
Browse files Browse the repository at this point in the history
This change ensures that the prefilters are propagated in the downstream retrievers only once.
It also removes the ability to extends `explainQuery` in the compound retriever. This is not needed
as the rank docs are now responsible for the explanation.
  • Loading branch information
jimczi committed Oct 17, 2024
1 parent 32ddbb3 commit ac3ca31
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,11 @@ public final QueryBuilder topDocsQuery() {
throw new IllegalStateException("Should not be called, missing a rewrite?");
}

@Override
public final QueryBuilder explainQuery() {
throw new IllegalStateException("Should not be called, missing a rewrite?");
}

@Override
public final void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder, boolean compoundUsed) {
throw new IllegalStateException("Should not be called, missing a rewrite?");
Expand Down Expand Up @@ -216,22 +221,12 @@ protected SearchSourceBuilder createSearchSourceBuilder(PointInTimeBuilder pit,
.trackTotalHits(false)
.storedFields(new StoredFieldsContext(false))
.size(rankWindowSize);
// apply the pre-filters downstream once
if (preFilterQueryBuilders.isEmpty() == false) {
retrieverBuilder.getPreFilterQueryBuilders().addAll(preFilterQueryBuilders);
}
retrieverBuilder.extractToSearchSourceBuilder(sourceBuilder, true);

// apply the pre-filters
if (preFilterQueryBuilders.size() > 0) {
QueryBuilder query = sourceBuilder.query();
BoolQueryBuilder newQuery = new BoolQueryBuilder();
if (query != null) {
newQuery.must(query);
}
preFilterQueryBuilders.forEach(newQuery::filter);
sourceBuilder.query(newQuery);
}

// Record the shard id in the sort result
List<SortBuilder<?>> sortBuilders = sourceBuilder.sorts() != null ? new ArrayList<>(sourceBuilder.sorts()) : new ArrayList<>();
if (sortBuilders.isEmpty()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import org.apache.lucene.search.ScoreDoc;
import org.elasticsearch.common.ParsingException;
import org.elasticsearch.features.NodeFeature;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.license.LicenseUtils;
import org.elasticsearch.search.builder.PointInTimeBuilder;
Expand All @@ -20,7 +19,6 @@
import org.elasticsearch.search.retriever.CompoundRetrieverBuilder;
import org.elasticsearch.search.retriever.RetrieverBuilder;
import org.elasticsearch.search.retriever.RetrieverParserContext;
import org.elasticsearch.search.retriever.rankdoc.RankDocsQueryBuilder;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentBuilder;
Expand Down Expand Up @@ -158,33 +156,18 @@ protected RankDoc[] combineInnerRetrieverResults(List<ScoreDoc[]> rankResults) {
return textSimilarityRankDocs;
}

@Override
public QueryBuilder explainQuery() {
// the original matching set of the TextSimilarityRank retriever is specified by its nested retriever
return new RankDocsQueryBuilder(rankDocs, new QueryBuilder[] { innerRetrievers.getFirst().retriever().explainQuery() }, true);
}

@Override
protected SearchSourceBuilder createSearchSourceBuilder(PointInTimeBuilder pit, RetrieverBuilder retrieverBuilder) {
var sourceBuilder = new SearchSourceBuilder().pointInTimeBuilder(pit)
.trackTotalHits(false)
.storedFields(new StoredFieldsContext(false))
.size(rankWindowSize);
// apply the pre-filters downstream once
if (preFilterQueryBuilders.isEmpty() == false) {
retrieverBuilder.getPreFilterQueryBuilders().addAll(preFilterQueryBuilders);
}
retrieverBuilder.extractToSearchSourceBuilder(sourceBuilder, true);

// apply the pre-filters
if (preFilterQueryBuilders.size() > 0) {
QueryBuilder query = sourceBuilder.query();
BoolQueryBuilder newQuery = new BoolQueryBuilder();
if (query != null) {
newQuery.must(query);
}
preFilterQueryBuilders.forEach(newQuery::filter);
sourceBuilder.query(newQuery);
}
sourceBuilder.rankBuilder(
new TextSimilarityRankBuilder(this.field, this.inferenceId, this.inferenceText, this.rankWindowSize, this.minScore)
);
Expand Down

0 comments on commit ac3ca31

Please sign in to comment.