From 3766d7c6bd1b4bb636c58f82020bf69600b1855a Mon Sep 17 00:00:00 2001 From: Jim Ferenczi Date: Wed, 16 Oct 2024 12:57:09 +0100 Subject: [PATCH] Add a rule query retriever reusing the compound builder --- .../retriever/CompoundRetrieverBuilder.java | 24 +--- .../retriever/RetrieverBuilderWrapper.java | 131 ++++++++++++++++++ .../retriever/QueryRuleRetrieverBuilder.java | 64 ++++++--- .../TextSimilarityRankRetrieverBuilder.java | 18 --- 4 files changed, 180 insertions(+), 57 deletions(-) create mode 100644 server/src/main/java/org/elasticsearch/search/retriever/RetrieverBuilderWrapper.java diff --git a/server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java index 8b6a2c4e7b078..629c74d1ea539 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java @@ -18,7 +18,6 @@ import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.search.TransportMultiSearchAction; -import org.elasticsearch.index.query.BoolQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryRewriteContext; import org.elasticsearch.search.builder.PointInTimeBuilder; @@ -163,6 +162,11 @@ public final QueryBuilder topDocsQuery() { throw new IllegalStateException("Should not be called, missing a rewrite?"); } + @Override + public final QueryBuilder explainQuery() { + throw new IllegalStateException("Should not be called, missing a rewrite?"); + } + @Override public final void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder, boolean compoundUsed) { throw new IllegalStateException("Should not be called, missing a rewrite?"); @@ -218,23 +222,6 @@ protected SearchSourceBuilder 
createSearchSourceBuilder(PointInTimeBuilder pit, } retrieverBuilder.extractToSearchSourceBuilder(sourceBuilder, true); - // apply the pre-filters - if (preFilterQueryBuilders.size() > 0) { - QueryBuilder query = sourceBuilder.query(); - BoolQueryBuilder newQuery = new BoolQueryBuilder(); - if (query != null) { - newQuery.must(query); - } - preFilterQueryBuilders.forEach(newQuery::filter); - sourceBuilder.query(newQuery); - } - - addSort(sourceBuilder); - - return sourceBuilder; - } - - protected void addSort(SearchSourceBuilder sourceBuilder) { // Record the shard id in the sort result List> sortBuilders = sourceBuilder.sorts() != null ? new ArrayList<>(sourceBuilder.sorts()) : new ArrayList<>(); if (sortBuilders.isEmpty()) { @@ -242,6 +229,7 @@ protected void addSort(SearchSourceBuilder sourceBuilder) { } sortBuilders.add(new FieldSortBuilder(FieldSortBuilder.SHARD_DOC_FIELD_NAME)); sourceBuilder.sort(sortBuilders); + return sourceBuilder; } private RankDoc[] getRankDocs(SearchResponse searchResponse) { diff --git a/server/src/main/java/org/elasticsearch/search/retriever/RetrieverBuilderWrapper.java b/server/src/main/java/org/elasticsearch/search/retriever/RetrieverBuilderWrapper.java new file mode 100644 index 0000000000000..aff60e9ffdf95 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/retriever/RetrieverBuilderWrapper.java @@ -0,0 +1,131 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
+ */ + +package org.elasticsearch.search.retriever; + +import org.elasticsearch.action.ActionRequestValidationException; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.QueryRewriteContext; +import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.search.rank.RankDoc; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.List; + +/** + * A wrapper that can be used to modify the behaviour of an existing {@link RetrieverBuilder}. + */ +public abstract class RetrieverBuilderWrapper extends RetrieverBuilder { + protected final RetrieverBuilder in; + + protected RetrieverBuilderWrapper(RetrieverBuilder in) { + this.in = in; + } + + protected abstract T clone(RetrieverBuilder sub); + + @Override + public RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOException { + var inRewrite = in.rewrite(ctx); + if (inRewrite != in) { + return clone(inRewrite); + } + return this; + } + + @Override + public QueryBuilder topDocsQuery() { + return in.topDocsQuery(); + } + + @Override + public RetrieverBuilder minScore(Float minScore) { + return in.minScore(minScore); + } + + @Override + public List getPreFilterQueryBuilders() { + return in.preFilterQueryBuilders; + } + + @Override + public ActionRequestValidationException validate( + SearchSourceBuilder source, + ActionRequestValidationException validationException, + boolean allowPartialSearchResults + ) { + return in.validate(source, validationException, allowPartialSearchResults); + } + + @Override + public RetrieverBuilder retrieverName(String retrieverName) { + return in.retrieverName(retrieverName); + } + + @Override + public void setRankDocs(RankDoc[] rankDocs) { + in.setRankDocs(rankDocs); + } + + @Override + public boolean isCompound() { + return in.isCompound(); + } + + @Override + public QueryBuilder explainQuery() { + return in.explainQuery(); + } + + @Override + public Float minScore() { + return 
in.minScore(); + } + + @Override + public boolean isFragment() { + return in.isFragment(); + } + + @Override + public String toString() { + return in.toString(); + } + + @Override + public String retrieverName() { + return in.retrieverName(); + } + + @Override + public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder, boolean compoundUsed) { + in.extractToSearchSourceBuilder(searchSourceBuilder, compoundUsed); + } + + @Override + public String getName() { + return in.getName(); + } + + @Override + protected void doToXContent(XContentBuilder builder, Params params) throws IOException { + in.doToXContent(builder, params); + } + + @Override + protected boolean doEquals(Object o) { + return in.equals(o); + } + + @Override + protected int doHashCode() { + return in.doHashCode(); + } +} diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/rules/retriever/QueryRuleRetrieverBuilder.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/rules/retriever/QueryRuleRetrieverBuilder.java index 6633b060ee9b0..07d9f199b80fe 100644 --- a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/rules/retriever/QueryRuleRetrieverBuilder.java +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/rules/retriever/QueryRuleRetrieverBuilder.java @@ -14,12 +14,14 @@ import org.elasticsearch.license.LicenseUtils; import org.elasticsearch.search.builder.PointInTimeBuilder; import org.elasticsearch.search.builder.SearchSourceBuilder; -import org.elasticsearch.search.fetch.StoredFieldsContext; import org.elasticsearch.search.rank.RankDoc; import org.elasticsearch.search.retriever.CompoundRetrieverBuilder; import org.elasticsearch.search.retriever.RetrieverBuilder; +import org.elasticsearch.search.retriever.RetrieverBuilderWrapper; import org.elasticsearch.search.retriever.RetrieverParserContext; import 
org.elasticsearch.search.retriever.rankdoc.RankDocsQueryBuilder; +import org.elasticsearch.search.sort.ScoreSortBuilder; +import org.elasticsearch.search.sort.SortBuilder; import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentBuilder; @@ -29,6 +31,7 @@ import org.elasticsearch.xpack.core.XPackPlugin; import java.io.IOException; +import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Objects; @@ -93,9 +96,10 @@ public QueryRuleRetrieverBuilder( RetrieverBuilder retrieverBuilder, int rankWindowSize ) { - super(List.of(new RetrieverSource(retrieverBuilder, null)), rankWindowSize); + super(new ArrayList<>(), rankWindowSize); this.rulesetIds = rulesetIds; this.matchCriteria = matchCriteria; + addChild(new QueryRuleRetrieverBuilderWrapper(retrieverBuilder)); } public QueryRuleRetrieverBuilder( @@ -118,21 +122,20 @@ public String getName() { @Override protected SearchSourceBuilder createSearchSourceBuilder(PointInTimeBuilder pit, RetrieverBuilder retrieverBuilder) { - var sourceBuilder = new SearchSourceBuilder().pointInTimeBuilder(pit) - .trackTotalHits(false) - .storedFields(new StoredFieldsContext(false)) - .size(rankWindowSize); - retrieverBuilder.extractToSearchSourceBuilder(sourceBuilder, true); - - QueryBuilder query = sourceBuilder.query(); - if (query != null && query instanceof RuleQueryBuilder == false) { - QueryBuilder ruleQuery = new RuleQueryBuilder(query, matchCriteria, rulesetIds); - sourceBuilder.query(ruleQuery); - } + var ret = super.createSearchSourceBuilder(pit, retrieverBuilder); + checkValidSort(ret.sorts()); + ret.query(new RuleQueryBuilder(ret.query(), matchCriteria, rulesetIds)); + return ret; + } - addSort(sourceBuilder); + private static void checkValidSort(List<SortBuilder<?>> sortBuilders) { + if (sortBuilders.isEmpty()) { + return; + } - return sourceBuilder; + /* Validate only the primary sort: super.createSearchSourceBuilder always appends a shard-doc tiebreaker sort, so rejecting size() > 1 would fail every request. */ if (sortBuilders.get(0) instanceof 
ScoreSortBuilder == false) { + throw new IllegalArgumentException("Rule retrievers can only sort documents by relevance score, got: " + sortBuilders); + } } @Override @@ -160,12 +163,6 @@ protected RankDoc[] combineInnerRetrieverResults(List rankResults) { return rankDocs; } - @Override - public QueryBuilder explainQuery() { - // the original matching set of the QueryRuleRetriever retriever is specified by its nested retriever - return new RankDocsQueryBuilder(rankDocs, new QueryBuilder[] { innerRetrievers.getFirst().retriever().explainQuery() }, true); - } - @Override public boolean doEquals(Object o) { QueryRuleRetrieverBuilder that = (QueryRuleRetrieverBuilder) o; @@ -176,4 +173,29 @@ public boolean doEquals(Object o) { public int doHashCode() { return Objects.hash(super.doHashCode(), rulesetIds, matchCriteria); } + + class QueryRuleRetrieverBuilderWrapper extends RetrieverBuilderWrapper { + protected QueryRuleRetrieverBuilderWrapper(RetrieverBuilder sub) { + super(sub); + } + + @Override + protected QueryRuleRetrieverBuilderWrapper clone(RetrieverBuilder sub) { + return new QueryRuleRetrieverBuilderWrapper(sub); + } + + @Override + public QueryBuilder topDocsQuery() { + return new RuleQueryBuilder(in.topDocsQuery(), matchCriteria, rulesetIds); + } + + @Override + public QueryBuilder explainQuery() { + return new RankDocsQueryBuilder( + rankDocs, + new QueryBuilder[] { new RuleQueryBuilder(in.explainQuery(), matchCriteria, rulesetIds) }, + true + ); + } + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java index 8bccf6e7d1022..02603edbc7819 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java @@ -10,7 +10,6 @@ import org.apache.lucene.search.ScoreDoc; import org.elasticsearch.common.ParsingException; import org.elasticsearch.features.NodeFeature; -import org.elasticsearch.index.query.BoolQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.license.LicenseUtils; import org.elasticsearch.search.builder.PointInTimeBuilder; @@ -20,7 +19,6 @@ import org.elasticsearch.search.retriever.CompoundRetrieverBuilder; import org.elasticsearch.search.retriever.RetrieverBuilder; import org.elasticsearch.search.retriever.RetrieverParserContext; -import org.elasticsearch.search.retriever.rankdoc.RankDocsQueryBuilder; import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentBuilder; @@ -158,12 +156,6 @@ protected RankDoc[] combineInnerRetrieverResults(List rankResults) { return textSimilarityRankDocs; } - @Override - public QueryBuilder explainQuery() { - // the original matching set of the TextSimilarityRank retriever is specified by its nested retriever - return new RankDocsQueryBuilder(rankDocs, new QueryBuilder[] { innerRetrievers.getFirst().retriever().explainQuery() }, true); - } - @Override protected SearchSourceBuilder createSearchSourceBuilder(PointInTimeBuilder pit, RetrieverBuilder retrieverBuilder) { var sourceBuilder = new SearchSourceBuilder().pointInTimeBuilder(pit) @@ -175,16 +167,6 @@ protected SearchSourceBuilder createSearchSourceBuilder(PointInTimeBuilder pit, } retrieverBuilder.extractToSearchSourceBuilder(sourceBuilder, true); - // apply the pre-filters - if (preFilterQueryBuilders.size() > 0) { - QueryBuilder query = sourceBuilder.query(); - BoolQueryBuilder newQuery = new BoolQueryBuilder(); - if (query != null) { - newQuery.must(query); - } - 
preFilterQueryBuilders.forEach(newQuery::filter); - sourceBuilder.query(newQuery); - } sourceBuilder.rankBuilder( new TextSimilarityRankBuilder(this.field, this.inferenceId, this.inferenceText, this.rankWindowSize, this.minScore) );