Skip to content

Commit

Permalink
Fix for propagating filters from compound to inner retrievers
Browse files Browse the repository at this point in the history
  • Loading branch information
pmpailis committed Dec 5, 2024
1 parent a074337 commit 03c1f75
Show file tree
Hide file tree
Showing 12 changed files with 174 additions and 57 deletions.
5 changes: 5 additions & 0 deletions docs/changelog/117914.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 117914
summary: Fix for propagating filters from compound to inner retrievers
area: Ranking
type: bug
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.search.TransportMultiSearchAction;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.features.NodeFeature;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.rest.RestStatus;
Expand All @@ -47,6 +47,8 @@
*/
public abstract class CompoundRetrieverBuilder<T extends CompoundRetrieverBuilder<T>> extends RetrieverBuilder {

public static final NodeFeature INNER_RETRIEVERS_FILTER_SUPPORT = new NodeFeature("inner_retrievers_filter_support");

public record RetrieverSource(RetrieverBuilder retriever, SearchSourceBuilder source) {}

protected final int rankWindowSize;
Expand All @@ -65,9 +67,9 @@ public T addChild(RetrieverBuilder retrieverBuilder) {

/**
* Returns a clone of the original retriever, replacing the sub-retrievers with
* the provided {@code newChildRetrievers}.
* the provided {@code newChildRetrievers} and the filters with the {@code newPreFilterQueryBuilders}.
*/
protected abstract T clone(List<RetrieverSource> newChildRetrievers);
protected abstract T clone(List<RetrieverSource> newChildRetrievers, List<QueryBuilder> newPreFilterQueryBuilders);

/**
* Combines the provided {@code rankResults} to return the final top documents.
Expand All @@ -86,13 +88,25 @@ public final RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOExceptio
}

// Rewrite prefilters
boolean hasChanged = false;
// We eagerly rewrite prefilters, because some of the innerRetrievers
// could be compound too, so we want to propagate all the necessary filter information to them
// and have it available as part of their own rewrite step
var newPreFilters = rewritePreFilters(ctx);
hasChanged |= newPreFilters != preFilterQueryBuilders;
if (newPreFilters != preFilterQueryBuilders) {
return clone(innerRetrievers, newPreFilters);
}

boolean hasChanged = false;
// Rewrite retriever sources
List<RetrieverSource> newRetrievers = new ArrayList<>();
for (var entry : innerRetrievers) {
// We propagate the filters only to compound retrievers, because for those the filters won't be
// attached through createSearchSourceBuilder.
// We could remove this check, but then the same filters would be added multiple times
// whenever an inner retriever rewrites itself and we re-enter this rewrite step.
if (entry.retriever.isCompound() && false == preFilterQueryBuilders.isEmpty()) {
entry.retriever.getPreFilterQueryBuilders().addAll(preFilterQueryBuilders);
}
RetrieverBuilder newRetriever = entry.retriever.rewrite(ctx);
if (newRetriever != entry.retriever) {
newRetrievers.add(new RetrieverSource(newRetriever, null));
Expand All @@ -107,7 +121,7 @@ public final RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOExceptio
}
}
if (hasChanged) {
return clone(newRetrievers);
return clone(newRetrievers, newPreFilters);
}

// execute searches
Expand Down Expand Up @@ -167,19 +181,19 @@ public void onFailure(Exception e) {
});
});

return new RankDocsRetrieverBuilder(
rankWindowSize,
newRetrievers.stream().map(s -> s.retriever).toList(),
results::get,
newPreFilters
);
return new RankDocsRetrieverBuilder(rankWindowSize, newRetrievers.stream().map(s -> s.retriever).toList(), results::get);
}

/**
 * Not supported on a compound retriever: this implementation throws unconditionally.
 * The error message hints that the retriever should have been rewritten before this call.
 *
 * @throws IllegalStateException always
 */
@Override
public final QueryBuilder topDocsQuery() {
    final String reason = "Should not be called, missing a rewrite?";
    throw new IllegalStateException(reason);
}

/**
 * Not supported on a compound retriever: this implementation throws unconditionally.
 * The error message hints that the retriever should have been rewritten before this call.
 *
 * @throws IllegalStateException always
 */
@Override
public final QueryBuilder explainQuery() {
    final String reason = "Should not be called, missing a rewrite?";
    throw new IllegalStateException(reason);
}

@Override
public final void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder, boolean compoundUsed) {
throw new IllegalStateException("Should not be called, missing a rewrite?");
Expand Down Expand Up @@ -237,22 +251,12 @@ protected final SearchSourceBuilder createSearchSourceBuilder(PointInTimeBuilder
.trackTotalHits(false)
.storedFields(new StoredFieldsContext(false))
.size(rankWindowSize);
// apply the pre-filters downstream once
if (preFilterQueryBuilders.isEmpty() == false) {
retrieverBuilder.getPreFilterQueryBuilders().addAll(preFilterQueryBuilders);
}
retrieverBuilder.extractToSearchSourceBuilder(sourceBuilder, true);

// apply the pre-filters
if (preFilterQueryBuilders.size() > 0) {
QueryBuilder query = sourceBuilder.query();
BoolQueryBuilder newQuery = new BoolQueryBuilder();
if (query != null) {
newQuery.must(query);
}
preFilterQueryBuilders.forEach(newQuery::filter);
sourceBuilder.query(newQuery);
}

// Record the shard id in the sort result
List<SortBuilder<?>> sortBuilders = sourceBuilder.sorts() != null ? new ArrayList<>(sourceBuilder.sorts()) : new ArrayList<>();
if (sortBuilders.isEmpty()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,8 +184,7 @@ public RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOException {
ll.onResponse(null);
}));
});
var rewritten = new KnnRetrieverBuilder(this, () -> toSet.get(), null);
return rewritten;
return new KnnRetrieverBuilder(this, () -> toSet.get(), null);
}
return super.rewrite(ctx);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,19 +33,13 @@ public class RankDocsRetrieverBuilder extends RetrieverBuilder {
final List<RetrieverBuilder> sources;
final Supplier<RankDoc[]> rankDocs;

public RankDocsRetrieverBuilder(
int rankWindowSize,
List<RetrieverBuilder> sources,
Supplier<RankDoc[]> rankDocs,
List<QueryBuilder> preFilterQueryBuilders
) {
public RankDocsRetrieverBuilder(int rankWindowSize, List<RetrieverBuilder> sources, Supplier<RankDoc[]> rankDocs) {
this.rankWindowSize = rankWindowSize;
this.rankDocs = rankDocs;
if (sources == null || sources.isEmpty()) {
throw new IllegalArgumentException("sources must not be null or empty");
}
this.sources = sources;
this.preFilterQueryBuilders = preFilterQueryBuilders;
}

@Override
Expand Down Expand Up @@ -73,10 +67,6 @@ private boolean sourceShouldRewrite(QueryRewriteContext ctx) throws IOException
@Override
public RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOException {
assert false == sourceShouldRewrite(ctx) : "retriever sources should be rewritten first";
var rewrittenFilters = rewritePreFilters(ctx);
if (rewrittenFilters != preFilterQueryBuilders) {
return new RankDocsRetrieverBuilder(rankWindowSize, sources, rankDocs, rewrittenFilters);
}
return this;
}

Expand All @@ -94,7 +84,7 @@ public QueryBuilder topDocsQuery() {
boolQuery.should(query);
}
}
// ignore prefilters of this level, they are already propagated to children
// ignore prefilters of this level, they were already propagated to children
return boolQuery;
}

Expand Down Expand Up @@ -133,7 +123,7 @@ public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder
} else {
rankQuery = new RankDocsQueryBuilder(rankDocResults, null, false);
}
// ignore prefilters of this level, they are already propagated to children
// ignore prefilters of this level, they were already propagated to children
searchSourceBuilder.query(rankQuery);
if (sourceHasMinScore()) {
searchSourceBuilder.minScore(this.minScore() == null ? Float.MIN_VALUE : this.minScore());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,7 @@ private List<QueryBuilder> preFilters(QueryRewriteContext queryRewriteContext) t
}

private RankDocsRetrieverBuilder createRandomRankDocsRetrieverBuilder(QueryRewriteContext queryRewriteContext) throws IOException {
return new RankDocsRetrieverBuilder(
randomIntBetween(1, 100),
innerRetrievers(queryRewriteContext),
rankDocsSupplier(),
preFilters(queryRewriteContext)
);
return new RankDocsRetrieverBuilder(randomIntBetween(1, 100), innerRetrievers(queryRewriteContext), rankDocsSupplier());
}

public void testExtractToSearchSourceBuilder() throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@
/**
* A SearchPlugin to exercise query vector builder
*/
class TestQueryVectorBuilderPlugin implements SearchPlugin {
public class TestQueryVectorBuilderPlugin implements SearchPlugin {

static class TestQueryVectorBuilder implements QueryVectorBuilder {
public static class TestQueryVectorBuilder implements QueryVectorBuilder {
private static final String NAME = "test_query_vector_builder";

private static final ParseField QUERY_VECTOR = new ParseField("query_vector");
Expand All @@ -47,11 +47,11 @@ static class TestQueryVectorBuilder implements QueryVectorBuilder {

private List<Float> vectorToBuild;

TestQueryVectorBuilder(List<Float> vectorToBuild) {
public TestQueryVectorBuilder(List<Float> vectorToBuild) {
this.vectorToBuild = vectorToBuild;
}

TestQueryVectorBuilder(float[] expected) {
public TestQueryVectorBuilder(float[] expected) {
this.vectorToBuild = new ArrayList<>(expected.length);
for (float f : expected) {
vectorToBuild.add(f);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
package org.elasticsearch.search.retriever;

import org.apache.lucene.search.ScoreDoc;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.search.rank.RankDoc;
import org.elasticsearch.xcontent.XContentBuilder;

Expand All @@ -23,16 +24,17 @@ public class TestCompoundRetrieverBuilder extends CompoundRetrieverBuilder<TestC
public static final String NAME = "test_compound_retriever_builder";

public TestCompoundRetrieverBuilder(int rankWindowSize) {
this(new ArrayList<>(), rankWindowSize);
this(new ArrayList<>(), rankWindowSize, new ArrayList<>());
}

TestCompoundRetrieverBuilder(List<RetrieverSource> childRetrievers, int rankWindowSize) {
TestCompoundRetrieverBuilder(List<RetrieverSource> childRetrievers, int rankWindowSize, List<QueryBuilder> preFilterQueryBuilders) {
super(childRetrievers, rankWindowSize);
this.preFilterQueryBuilders = preFilterQueryBuilders;
}

@Override
protected TestCompoundRetrieverBuilder clone(List<RetrieverSource> newChildRetrievers) {
return new TestCompoundRetrieverBuilder(newChildRetrievers, rankWindowSize);
protected TestCompoundRetrieverBuilder clone(List<RetrieverSource> newChildRetrievers, List<QueryBuilder> newPreFilterQueryBuilders) {
return new TestCompoundRetrieverBuilder(newChildRetrievers, rankWindowSize, newPreFilterQueryBuilders);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,10 @@ public TextSimilarityRankRetrieverBuilder(
}

@Override
protected TextSimilarityRankRetrieverBuilder clone(List<RetrieverSource> newChildRetrievers) {
protected TextSimilarityRankRetrieverBuilder clone(
List<RetrieverSource> newChildRetrievers,
List<QueryBuilder> newPreFilterQueryBuilders
) {
return new TextSimilarityRankRetrieverBuilder(
newChildRetrievers,
inferenceId,
Expand All @@ -139,7 +142,7 @@ protected TextSimilarityRankRetrieverBuilder clone(List<RetrieverSource> newChil
rankWindowSize,
minScore,
retrieverName,
preFilterQueryBuilders
newPreFilterQueryBuilders
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import org.elasticsearch.search.sort.SortOrder;
import org.elasticsearch.search.vectors.KnnVectorQueryBuilder;
import org.elasticsearch.search.vectors.QueryVectorBuilder;
import org.elasticsearch.search.vectors.TestQueryVectorBuilderPlugin;
import org.elasticsearch.test.ESIntegTestCase;
import org.elasticsearch.test.hamcrest.ElasticsearchAssertions;
import org.elasticsearch.xcontent.XContentBuilder;
Expand All @@ -57,7 +58,6 @@
public class RRFRetrieverBuilderIT extends ESIntegTestCase {

protected static String INDEX = "test_index";
protected static final String ID_FIELD = "_id";
protected static final String DOC_FIELD = "doc";
protected static final String TEXT_FIELD = "text";
protected static final String VECTOR_FIELD = "vector";
Expand Down Expand Up @@ -743,6 +743,42 @@ public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder
expectThrows(UnsupportedOperationException.class, () -> client().prepareSearch(INDEX).setSource(source).get());
}

/**
 * Checks that a pre-filter attached to the top-level RRF retriever also reaches its inner
 * retrievers, including a kNN retriever that is configured with a {@code TestQueryVectorBuilder}.
 * With the filter restricting results to {@code doc_7}, exactly one hit is expected.
 */
public void testRRFFiltersPropagatedToKnnQueryVectorBuilder() {
    final int rankWindowSize = 100;
    final int rankConstant = 10;

    // on its own match_all would return every document; the top-level filter narrows it to doc 7
    StandardRetrieverBuilder matchAllRetriever = new StandardRetrieverBuilder(QueryBuilders.matchAllQuery());
    // the kNN retriever is likewise expected to return only doc 7 once the filter is applied
    KnnRetrieverBuilder vectorRetriever = new KnnRetrieverBuilder(
        "vector",
        null,
        new TestQueryVectorBuilderPlugin.TestQueryVectorBuilder(new float[] { 3 }),
        10,
        10,
        null
    );

    SearchSourceBuilder source = new SearchSourceBuilder();
    source.retriever(
        new RRFRetrieverBuilder(
            Arrays.asList(
                new CompoundRetrieverBuilder.RetrieverSource(matchAllRetriever, null),
                new CompoundRetrieverBuilder.RetrieverSource(vectorRetriever, null)
            ),
            rankWindowSize,
            rankConstant
        )
    );
    // the filter is added on the compound (RRF) retriever only, not on the inner retrievers
    source.retriever().getPreFilterQueryBuilders().add(QueryBuilders.boolQuery().must(QueryBuilders.termQuery(DOC_FIELD, "doc_7")));
    source.size(10);

    ElasticsearchAssertions.assertResponse(client().prepareSearch(INDEX).setSource(source), resp -> {
        assertNull(resp.pointInTimeId());
        var hits = resp.getHits();
        assertNotNull(hits.getTotalHits());
        assertThat(hits.getTotalHits().value(), equalTo(1L));
        assertThat(hits.getHits()[0].getId(), equalTo("doc_7"));
    });
}

public void testRewriteOnce() {
final float[] vector = new float[] { 1 };
AtomicInteger numAsyncCalls = new AtomicInteger();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import java.util.Set;

import static org.elasticsearch.search.retriever.CompoundRetrieverBuilder.INNER_RETRIEVERS_FILTER_SUPPORT;
import static org.elasticsearch.xpack.rank.rrf.RRFRetrieverBuilder.RRF_RETRIEVER_COMPOSITION_SUPPORTED;

/**
Expand All @@ -23,4 +24,9 @@ public class RRFFeatures implements FeatureSpecification {
public Set<NodeFeature> getFeatures() {
return Set.of(RRFRetrieverBuilder.RRF_RETRIEVER_SUPPORTED, RRF_RETRIEVER_COMPOSITION_SUPPORTED);
}

/**
 * Returns the test features declared by this specification; the set contains a single entry,
 * {@code INNER_RETRIEVERS_FILTER_SUPPORT}.
 */
@Override
public Set<NodeFeature> getTestFeatures() {
    final NodeFeature innerFilterSupport = INNER_RETRIEVERS_FILTER_SUPPORT;
    return Set.of(innerFilterSupport);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.elasticsearch.common.ParsingException;
import org.elasticsearch.common.util.Maps;
import org.elasticsearch.features.NodeFeature;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.license.LicenseUtils;
import org.elasticsearch.search.rank.RankBuilder;
import org.elasticsearch.search.rank.RankDoc;
Expand Down Expand Up @@ -108,8 +109,10 @@ public String getName() {
}

@Override
protected RRFRetrieverBuilder clone(List<RetrieverSource> newRetrievers) {
return new RRFRetrieverBuilder(newRetrievers, this.rankWindowSize, this.rankConstant);
protected RRFRetrieverBuilder clone(List<RetrieverSource> newRetrievers, List<QueryBuilder> newPreFilterQueryBuilders) {
RRFRetrieverBuilder clone = new RRFRetrieverBuilder(newRetrievers, this.rankWindowSize, this.rankConstant);
clone.preFilterQueryBuilders = newPreFilterQueryBuilders;
return clone;
}

@Override
Expand Down
Loading

0 comments on commit 03c1f75

Please sign in to comment.