Skip to content

Commit

Permalink
Ensure that all rewriteable are called in retrievers
Browse files Browse the repository at this point in the history
This PR ensures that every retriever applies the rewrite to all of its rewriteables.
Rewriting eagerly at the retriever level ensures that we don't rewrite the same query multiple times
when compound retrievers are used.
  • Loading branch information
jimczi committed Oct 8, 2024
1 parent 2ba9bc9 commit f9771dd
Show file tree
Hide file tree
Showing 5 changed files with 175 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@

package org.elasticsearch.search.retriever;

import org.apache.lucene.util.SetOnce;
import org.elasticsearch.common.ParsingException;
import org.elasticsearch.features.NodeFeature;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.retriever.rankdoc.RankDocsQueryBuilder;
import org.elasticsearch.search.vectors.ExactKnnQueryBuilder;
Expand All @@ -29,7 +31,9 @@
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.function.Supplier;

import static org.elasticsearch.common.Strings.format;
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;

Expand Down Expand Up @@ -96,7 +100,7 @@ public static KnnRetrieverBuilder fromXContent(XContentParser parser, RetrieverP
}

private final String field;
private final float[] queryVector;
private final Supplier<float[]> queryVector;
private final QueryVectorBuilder queryVectorBuilder;
private final int k;
private final int numCands;
Expand All @@ -111,20 +115,64 @@ public KnnRetrieverBuilder(
Float similarity
) {
this.field = field;
this.queryVector = queryVector;
this.queryVector = () -> queryVector;
this.queryVectorBuilder = queryVectorBuilder;
this.k = k;
this.numCands = numCands;
this.similarity = similarity;
}

// ---- FOR TESTING XCONTENT PARSING ----
/**
 * Copy constructor used by {@code rewrite}: the query vector supplier and the query vector
 * builder come from the arguments, every other setting is taken from {@code clone}.
 *
 * @param clone              retriever whose remaining state is copied (references are shared, not deep-copied)
 * @param queryVector        supplier of the query vector for the new instance
 * @param queryVectorBuilder query vector builder for the new instance; {@code rewrite} passes {@code null}
 *                           once the vector is being resolved through the supplier
 */
private KnnRetrieverBuilder(KnnRetrieverBuilder clone, Supplier<float[]> queryVector, QueryVectorBuilder queryVectorBuilder) {
    // State carried over unchanged from the source retriever.
    this.retrieverName = clone.retrieverName;
    this.preFilterQueryBuilders = clone.preFilterQueryBuilders;
    this.field = clone.field;
    this.similarity = clone.similarity;
    this.numCands = clone.numCands;
    this.k = clone.k;

    // State that the caller replaces.
    this.queryVectorBuilder = queryVectorBuilder;
    this.queryVector = queryVector;
}

/** Returns the name of this retriever type, i.e. the {@code NAME} constant. */
@Override
public String getName() {
return NAME;
}

/**
 * Rewrites this retriever, returning after the first step that applies:
 * <ol>
 *   <li>If rewriting the pre-filters produced a different list, return a copy that carries the
 *       rewritten filters and the current query vector supplier and builder.</li>
 *   <li>If a query vector builder is set, register an async action on {@code ctx} that asks the
 *       builder for the vector, and return a copy whose query vector supplier reads the result
 *       and whose query vector builder is {@code null}.</li>
 *   <li>Otherwise defer to {@code super.rewrite(ctx)}.</li>
 * </ol>
 *
 * @param ctx the rewrite context used for the pre-filters and for registering the async action
 * @return the rewritten retriever as described above
 * @throws IOException declared by the overridden method
 */
@Override
public RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOException {
    final var filters = rewritePreFilters(ctx);
    if (filters != preFilterQueryBuilders) {
        final var withFilters = new KnnRetrieverBuilder(this, queryVector, queryVectorBuilder);
        withFilters.preFilterQueryBuilders = filters;
        return withFilters;
    }

    if (queryVectorBuilder == null) {
        return super.rewrite(ctx);
    }

    // Holder filled in by the async action; the returned copy reads it lazily through its supplier.
    final SetOnce<float[]> resolvedVector = new SetOnce<>();
    ctx.registerAsyncAction((client, listener) -> {
        queryVectorBuilder.buildVector(client, listener.delegateFailureAndWrap((delegate, vector) -> {
            resolvedVector.set(vector);
            if (vector != null) {
                delegate.onResponse(null);
                return;
            }
            // A builder that yields no vector is reported as a failure of the async action.
            delegate.onFailure(
                new IllegalArgumentException(
                    format(
                        "[%s] with name [%s] returned null query_vector",
                        QUERY_VECTOR_BUILDER_FIELD.getPreferredName(),
                        queryVectorBuilder.getWriteableName()
                    )
                )
            );
        }));
    });
    return new KnnRetrieverBuilder(this, resolvedVector::get, null);
}

@Override
public QueryBuilder topDocsQuery() {
assert rankDocs != null : "rankDocs should have been materialized by now";
Expand All @@ -142,7 +190,7 @@ public QueryBuilder explainQuery() {
assert rankDocs != null : "rankDocs should have been materialized by now";
var rankDocsQuery = new RankDocsQueryBuilder(
rankDocs,
new QueryBuilder[] { new ExactKnnQueryBuilder(VectorData.fromFloats(queryVector), field, similarity) },
new QueryBuilder[] { new ExactKnnQueryBuilder(VectorData.fromFloats(queryVector.get()), field, similarity) },
true
);
if (preFilterQueryBuilders.isEmpty()) {
Expand All @@ -157,7 +205,7 @@ public QueryBuilder explainQuery() {
public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder, boolean compoundUsed) {
KnnSearchBuilder knnSearchBuilder = new KnnSearchBuilder(
field,
VectorData.fromFloats(queryVector),
VectorData.fromFloats(queryVector.get()),
queryVectorBuilder,
k,
numCands,
Expand All @@ -174,6 +222,7 @@ public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder
searchSourceBuilder.knnSearch(knnSearchBuilders);
}

// ---- FOR TESTING XCONTENT PARSING ----
@Override
public void doToXContent(XContentBuilder builder, Params params) throws IOException {
builder.field(FIELD_FIELD.getPreferredName(), field);
Expand All @@ -199,15 +248,15 @@ public boolean doEquals(Object o) {
return k == that.k
&& numCands == that.numCands
&& Objects.equals(field, that.field)
&& Arrays.equals(queryVector, that.queryVector)
&& Arrays.equals(queryVector.get(), that.queryVector.get())
&& Objects.equals(queryVectorBuilder, that.queryVectorBuilder)
&& Objects.equals(similarity, that.similarity);
}

/**
 * Computes the hash code from the field, the query vector builder, {@code k}, {@code numCands},
 * the similarity and the contents of the currently supplied query vector.
 *
 * @return the combined hash code
 */
@Override
public int doHashCode() {
    int result = Objects.hash(field, queryVectorBuilder, k, numCands, similarity);
    // queryVector is a Supplier<float[]>, so the array has to be fetched before hashing its
    // contents. The former Arrays.hashCode(queryVector) line is dropped: it has no overload for
    // a Supplier and would have mixed the vector into the hash a second time.
    result = 31 * result + Arrays.hashCode(queryVector.get());
    return result;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import org.elasticsearch.index.query.AbstractQueryBuilder;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.builder.SubSearchSourceBuilder;
import org.elasticsearch.search.collapse.CollapseBuilder;
Expand All @@ -27,6 +28,7 @@
import org.elasticsearch.xcontent.XContentParser;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;

Expand Down Expand Up @@ -105,6 +107,43 @@ public StandardRetrieverBuilder(QueryBuilder queryBuilder) {
this.queryBuilder = queryBuilder;
}

/**
 * Copy constructor used by {@code rewrite}: creates a retriever that shares every setting of
 * {@code clone} by reference; the caller then overwrites the parts that were rewritten.
 *
 * @param clone the retriever to copy from
 */
private StandardRetrieverBuilder(StandardRetrieverBuilder clone) {
    // What to search for and how to filter it.
    this.queryBuilder = clone.queryBuilder;
    this.preFilterQueryBuilders = clone.preFilterQueryBuilders;

    // Ordering, paging and result shaping.
    this.sortBuilders = clone.sortBuilders;
    this.searchAfterBuilder = clone.searchAfterBuilder;
    this.collapseBuilder = clone.collapseBuilder;
    this.terminateAfter = clone.terminateAfter;
    this.minScore = clone.minScore;

    this.retrieverName = clone.retrieverName;
}

/**
 * Rewrites the sort builders, the pre-filters and the query of this retriever against
 * {@code ctx}.
 *
 * @param ctx the rewrite context handed to every nested rewriteable
 * @return {@code this} when no nested rewrite produced a different instance, otherwise a copy
 *         that carries the rewritten sorts, pre-filters and query
 * @throws IOException if one of the nested rewrites throws it
 */
@Override
public RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOException {
    boolean changed = false;
    List<SortBuilder<?>> newSortBuilders = null;
    if (sortBuilders != null) {
        newSortBuilders = new ArrayList<>(sortBuilders.size());
        for (var sort : sortBuilders) {
            var newSort = sort.rewrite(ctx);
            newSortBuilders.add(newSort);
            // Compare identities: List.add always returns true, so using its result would
            // report a change on every round and the rewrite would never reach a fixed point.
            changed |= newSort != sort;
        }
    }
    var rewrittenFilters = rewritePreFilters(ctx);
    changed |= rewrittenFilters != preFilterQueryBuilders;

    // NOTE(review): queryBuilder is guarded like sortBuilders on the assumption that a retriever
    // may be built without a query - confirm against the other constructors.
    QueryBuilder queryBuilderRewrite = null;
    if (queryBuilder != null) {
        queryBuilderRewrite = queryBuilder.rewrite(ctx);
        changed |= queryBuilderRewrite != queryBuilder;
    }

    if (changed) {
        var rewritten = new StandardRetrieverBuilder(this);
        rewritten.sortBuilders = newSortBuilders;
        // Carry the rewritten filters over; assigning the old list would silently drop the rewrite.
        rewritten.preFilterQueryBuilders = rewrittenFilters;
        rewritten.queryBuilder = queryBuilderRewrite;
        return rewritten;
    }
    return this;
}

@Override
public QueryBuilder topDocsQuery() {
if (preFilterQueryBuilders.isEmpty()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.index.query.AbstractQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.index.query.SearchExecutionContext;
import org.elasticsearch.search.rank.RankDoc;
import org.elasticsearch.xcontent.XContentBuilder;
Expand Down Expand Up @@ -54,6 +55,22 @@ public RankDocsQueryBuilder(StreamInput in) throws IOException {
}
}

/**
 * Rewrites each wrapped query builder. When at least one of them rewrites to a different
 * instance, a new {@code RankDocsQueryBuilder} with the same rank docs and {@code onlyRankDocs}
 * flag is returned around the rewritten builders; otherwise (or when there are no wrapped
 * builders) the result of {@code super.doRewrite} is returned.
 *
 * @param queryRewriteContext the context passed to every nested rewrite
 * @return the rewritten query builder
 * @throws IOException if a nested rewrite throws it
 */
@Override
protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException {
    if (queryBuilders == null) {
        return super.doRewrite(queryRewriteContext);
    }

    final QueryBuilder[] rewritten = new QueryBuilder[queryBuilders.length];
    boolean anyChanged = false;
    int idx = 0;
    for (QueryBuilder current : queryBuilders) {
        final QueryBuilder next = current.rewrite(queryRewriteContext);
        rewritten[idx++] = next;
        if (next != current) {
            anyChanged = true;
        }
    }

    return anyChanged
        ? new RankDocsQueryBuilder(rankDocs, rewritten, onlyRankDocs)
        : super.doRewrite(queryRewriteContext);
}

/** Returns the rank docs array held by this builder (the same array instance, not a copy). */
RankDoc[] rankDocs() {
return rankDocs;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ public KnnVectorQueryBuilder(String fieldName, float[] queryVector, Integer k, I
this(fieldName, VectorData.fromFloats(queryVector), null, null, k, numCands, vectorSimilarity);
}

protected KnnVectorQueryBuilder(
public KnnVectorQueryBuilder(
String fieldName,
QueryVectorBuilder queryVectorBuilder,
Integer k,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@
package org.elasticsearch.xpack.rank.rrf;

import org.apache.lucene.search.TotalHits;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.search.SearchRequestBuilder;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.query.InnerHitBuilder;
import org.elasticsearch.index.query.QueryBuilder;
Expand All @@ -24,16 +28,23 @@
import org.elasticsearch.search.retriever.TestRetrieverBuilder;
import org.elasticsearch.search.sort.FieldSortBuilder;
import org.elasticsearch.search.sort.SortOrder;
import org.elasticsearch.search.vectors.KnnVectorQueryBuilder;
import org.elasticsearch.search.vectors.QueryVectorBuilder;
import org.elasticsearch.test.ESIntegTestCase;
import org.elasticsearch.test.hamcrest.ElasticsearchAssertions;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentType;
import org.junit.Before;

import java.io.IOException;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;

import static org.elasticsearch.cluster.metadata.IndexMetadata.SETTING_NUMBER_OF_SHARDS;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertResponse;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThan;
Expand Down Expand Up @@ -652,4 +663,55 @@ public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder
source.aggregation(AggregationBuilders.terms("topic_agg").field(TOPIC_FIELD));
expectThrows(UnsupportedOperationException.class, () -> client().prepareSearch(INDEX).setSource(source).get());
}

/**
 * Runs an RRF retriever that combines a kNN retriever and a standard retriever wrapping a
 * {@code KnnVectorQueryBuilder}, both sharing one {@code QueryVectorBuilder} that counts its
 * {@code buildVector} invocations. The counter is expected to be 2 after the first search and
 * 4 after the second search, which additionally enables explain.
 */
public void testRewriteOnce() {
    final AtomicInteger buildVectorCalls = new AtomicInteger();
    final float[] queryVector = new float[] { 1 };

    // Only buildVector is expected to be used; every other method fails loudly.
    QueryVectorBuilder countingBuilder = new QueryVectorBuilder() {
        @Override
        public void buildVector(Client client, ActionListener<float[]> listener) {
            buildVectorCalls.incrementAndGet();
            listener.onResponse(queryVector);
        }

        @Override
        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
            throw new IllegalStateException("Should not be called");
        }

        @Override
        public void writeTo(StreamOutput out) throws IOException {
            throw new IllegalStateException("Should not be called");
        }

        @Override
        public TransportVersion getMinimalSupportedVersion() {
            throw new IllegalStateException("Should not be called");
        }

        @Override
        public String getWriteableName() {
            throw new IllegalStateException("Should not be called");
        }
    };

    var knnRetriever = new KnnRetrieverBuilder("vector", null, countingBuilder, 10, 10, null);
    var standardRetriever = new StandardRetrieverBuilder(new KnnVectorQueryBuilder("vector", countingBuilder, 10, 10, null));
    var rrfRetriever = new RRFRetrieverBuilder(
        List.of(
            new CompoundRetrieverBuilder.RetrieverSource(knnRetriever, null),
            new CompoundRetrieverBuilder.RetrieverSource(standardRetriever, null)
        ),
        10,
        10
    );

    var plainSource = new SearchSourceBuilder().retriever(rrfRetriever);
    assertResponse(
        client().prepareSearch(INDEX).setSource(plainSource),
        response -> assertThat(response.getHits().getTotalHits().value, is(4L))
    );
    assertThat(buildVectorCalls.get(), equalTo(2));

    // check that we use the rewritten vector to build the explain query
    var explainSource = new SearchSourceBuilder().retriever(rrfRetriever).explain(true);
    assertResponse(
        client().prepareSearch(INDEX).setSource(explainSource),
        response -> assertThat(response.getHits().getTotalHits().value, is(4L))
    );
    assertThat(buildVectorCalls.get(), equalTo(4));
}
}

0 comments on commit f9771dd

Please sign in to comment.