From 9ec025ec6dc10f2f29c78c37d0f305ccda2c4f3c Mon Sep 17 00:00:00 2001 From: Jim Ferenczi Date: Mon, 17 Jun 2024 15:56:13 +0100 Subject: [PATCH] Add an abstract combine retriever and make RRF a simple extension --- .../uhighlight/CustomUnifiedHighlighter.java | 5 +- .../retriever/CombineRetrieverBuilder.java | 198 +++++++++++++ .../search/retriever/RankDocsQuery.java | 2 +- .../xpack/rank/rrf/RRFRetrieverBuilder.java | 259 ++++-------------- .../rrf/RRFRetrieverBuilderParsingTests.java | 15 +- 5 files changed, 259 insertions(+), 220 deletions(-) create mode 100644 server/src/main/java/org/elasticsearch/search/retriever/CombineRetrieverBuilder.java diff --git a/server/src/main/java/org/elasticsearch/lucene/search/uhighlight/CustomUnifiedHighlighter.java b/server/src/main/java/org/elasticsearch/lucene/search/uhighlight/CustomUnifiedHighlighter.java index 07eec973c77e0..dea699c0fc93f 100644 --- a/server/src/main/java/org/elasticsearch/lucene/search/uhighlight/CustomUnifiedHighlighter.java +++ b/server/src/main/java/org/elasticsearch/lucene/search/uhighlight/CustomUnifiedHighlighter.java @@ -32,6 +32,7 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.index.IndexSettings; import org.elasticsearch.index.search.ESToParentBlockJoinQuery; +import org.elasticsearch.search.retriever.RankDocsQuery; import org.elasticsearch.search.runtime.AbstractScriptFieldQuery; import org.elasticsearch.search.vectors.KnnScoreDocQuery; @@ -255,10 +256,10 @@ public void visitLeaf(Query leafQuery) { hasUnknownLeaf[0] = true; } /** - * KnnScoreDocQuery requires the same reader that built the docs + * {@link KnnScoreDocQuery} and {@link RankDocsQuery} require the same reader that built the docs * When using {@link HighlightFlag#WEIGHT_MATCHES} different readers are used and isn't supported by this query */ - if (leafQuery instanceof KnnScoreDocQuery) { + if (leafQuery instanceof KnnScoreDocQuery || leafQuery instanceof RankDocsQuery) { hasUnknownLeaf[0] = true; } super.visitLeaf(query); diff --git a/server/src/main/java/org/elasticsearch/search/retriever/CombineRetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/CombineRetrieverBuilder.java new file mode 100644 index 0000000000000..7075f0741d7c2 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/retriever/CombineRetrieverBuilder.java @@ -0,0 +1,198 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.search.retriever; + +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.util.SetOnce; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.search.MultiSearchRequest; +import org.elasticsearch.action.search.MultiSearchResponse; +import org.elasticsearch.action.search.SearchRequest; +import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.action.search.TransportMultiSearchAction; +import org.elasticsearch.index.query.BoolQueryBuilder; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.QueryRewriteContext; +import org.elasticsearch.search.builder.PointInTimeBuilder; +import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.search.fetch.StoredFieldsContext; +import org.elasticsearch.search.rank.RankDoc; +import org.elasticsearch.search.sort.FieldSortBuilder; +import org.elasticsearch.search.sort.ScoreSortBuilder; +import org.elasticsearch.search.sort.ShardDocSortField; +import org.elasticsearch.search.sort.SortBuilder; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; + +/** + * This abstract retriever is a compound retriever. + * It has a set of child retrievers, each returning a set of top documents. + * These documents are then combined and ranked according to the implementation of {@code combineQueryPhaseResults}. + */ +public abstract class CombineRetrieverBuilder> extends RetrieverBuilder { + public record RetrieverSource(RetrieverBuilder retriever, SearchSourceBuilder source) {} + + protected final int windowSize; + protected final List childRetrievers; + + protected CombineRetrieverBuilder(List childRetrievers, int windowSize) { + this.windowSize = windowSize; + this.childRetrievers = childRetrievers; + } + + @SuppressWarnings("unchecked") + public T addChild(RetrieverBuilder retrieverBuilder) { + childRetrievers.add(new RetrieverSource(retrieverBuilder, null)); + return (T) this; + } + + /** + * Returns a clone of the original retriever, replacing the sub-retrievers with + * the provided {@code newChildRetrievers}. + */ + public abstract T clone(T original, List newChildRetrievers); + + /** + * Combines the provided {@code rankResults} to return the final top documents. + */ + public abstract RankDoc[] combineQueryPhaseResults(List rankResults); + + @Override + public final boolean isCompound() { + return true; + } + + @Override + @SuppressWarnings("unchecked") + public final RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOException { + if (ctx.pointInTimeBuilder() == null) { + throw new IllegalStateException("PIT is required"); + } + + // Rewrite prefilters + boolean hasChanged = false; + var newPreFilters = rewritePreFilters(ctx); + hasChanged |= newPreFilters != preFilterQueryBuilders; + + // Rewrite retriever sources + List newRetrievers = new ArrayList<>(); + for (var entry : childRetrievers) { + RetrieverBuilder newRetriever = entry.retriever.rewrite(ctx); + if (newRetriever != entry.retriever) { + newRetrievers.add(new RetrieverSource(newRetriever, null)); + hasChanged |= newRetriever != entry.retriever; + } else if (newRetriever == entry.retriever) { + var sourceBuilder = entry.source != null ? entry.source : createSearchSourceBuilder(ctx.pointInTimeBuilder(), newRetriever); + var rewrittenSource = sourceBuilder.rewrite(ctx); + newRetrievers.add(new RetrieverSource(newRetriever, rewrittenSource)); + hasChanged |= rewrittenSource != entry.source; + } + } + if (hasChanged) { + return clone((T) this, newRetrievers); + } + + // execute searches + final SetOnce results = new SetOnce<>(); + final MultiSearchRequest multiSearchRequest = new MultiSearchRequest(); + for (var entry : childRetrievers) { + SearchRequest searchRequest = new SearchRequest().source(entry.source); + // The can match phase can reorder shards, so we disable it to ensure the stable ordering + searchRequest.setPreFilterShardSize(Integer.MAX_VALUE); + multiSearchRequest.add(searchRequest); + } + ctx.registerAsyncAction((client, listener) -> { + client.execute(TransportMultiSearchAction.TYPE, multiSearchRequest, new ActionListener<>() { + @Override + public void onResponse(MultiSearchResponse items) { + List topDocs = new ArrayList<>(); + for (int i = 0; i < items.getResponses().length; i++) { + var item = items.getResponses()[i]; + topDocs.add(getTopDocs(item.getResponse())); + } + results.set(combineQueryPhaseResults(topDocs)); + listener.onResponse(null); + } + + @Override + public void onFailure(Exception e) { + listener.onFailure(e); + } + }); + }); + + return new RankDocsRetrieverBuilder(windowSize, newRetrievers.stream().map(s -> s.retriever).toList(), results::get, newPreFilters); + } + + @Override + public final QueryBuilder topDocsQuery(QueryBuilder leadQuery) { + throw new IllegalStateException(getName() + " cannot be nested"); + } + + @Override + public final void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder, boolean compoundUsed) { + throw new IllegalStateException("Should not be called, missing a rewrite?"); + } + + @Override + @SuppressWarnings("unchecked") + public boolean doEquals(Object o) { + CombineRetrieverBuilder that = (CombineRetrieverBuilder) o; + return windowSize == that.windowSize && Objects.equals(childRetrievers, that.childRetrievers); + } + + @Override + public int doHashCode() { + return Objects.hash(childRetrievers, windowSize); + } + + private SearchSourceBuilder createSearchSourceBuilder(PointInTimeBuilder pit, RetrieverBuilder retrieverBuilder) { + var sourceBuilder = new SearchSourceBuilder().pointInTimeBuilder(pit) + .trackTotalHits(false) + .storedFields(new StoredFieldsContext(false)) + .size(windowSize); + retrieverBuilder.extractToSearchSourceBuilder(sourceBuilder, false); + + // apply the pre-filters + if (preFilterQueryBuilders.size() > 0) { + QueryBuilder query = sourceBuilder.query(); + BoolQueryBuilder newQuery = new BoolQueryBuilder(); + if (query != null) { + newQuery.must(query); + } + preFilterQueryBuilders.stream().forEach(newQuery::filter); + sourceBuilder.query(newQuery); + } + + // Record the shard id in the sort result + List> sortBuilders = sourceBuilder.sorts() != null ? new ArrayList<>(sourceBuilder.sorts()) : new ArrayList<>(); + if (sortBuilders.isEmpty()) { + sortBuilders.add(new ScoreSortBuilder()); + } + sortBuilders.add(new FieldSortBuilder(FieldSortBuilder.SHARD_DOC_FIELD_NAME)); + sourceBuilder.sort(sortBuilders); + return sourceBuilder; + } + + private ScoreDoc[] getTopDocs(SearchResponse searchResponse) { + int size = Math.min(windowSize, searchResponse.getHits().getHits().length); + ScoreDoc[] docs = new ScoreDoc[size]; + for (int i = 0; i < size; i++) { + var hit = searchResponse.getHits().getAt(i); + long sortValue = (long) hit.getRawSortValues()[hit.getRawSortValues().length - 1]; + int doc = ShardDocSortField.decodeDoc(sortValue); + int shardRequestIndex = ShardDocSortField.decodeShardRequestIndex(sortValue); + docs[i] = new ScoreDoc(doc, hit.getScore(), shardRequestIndex); + } + return docs; + } +} diff --git a/server/src/main/java/org/elasticsearch/search/retriever/RankDocsQuery.java b/server/src/main/java/org/elasticsearch/search/retriever/RankDocsQuery.java index a766a51cba958..01c51e832cd97 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/RankDocsQuery.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/RankDocsQuery.java @@ -156,7 +156,7 @@ public boolean isCacheable(LeafReaderContext ctx) { @Override public String toString(String field) { - return "ScoreAndDocQuery"; + return "RankDocsQuery"; } @Override diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java index 5695c1c4ad9c4..f2511fa1d8046 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java @@ -8,32 +8,15 @@ package org.elasticsearch.xpack.rank.rrf; import org.apache.lucene.search.ScoreDoc; -import org.apache.lucene.util.SetOnce; -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.search.MultiSearchRequest; -import org.elasticsearch.action.search.MultiSearchResponse; -import org.elasticsearch.action.search.SearchRequest; -import org.elasticsearch.action.search.SearchResponse; -import org.elasticsearch.action.search.TransportMultiSearchAction; import org.elasticsearch.common.ParsingException; import org.elasticsearch.common.util.Maps; import org.elasticsearch.features.NodeFeature; -import org.elasticsearch.index.query.BoolQueryBuilder; -import org.elasticsearch.index.query.QueryBuilder; -import org.elasticsearch.index.query.QueryRewriteContext; import org.elasticsearch.license.LicenseUtils; -import org.elasticsearch.search.builder.PointInTimeBuilder; -import org.elasticsearch.search.builder.SearchSourceBuilder; -import org.elasticsearch.search.fetch.StoredFieldsContext; import org.elasticsearch.search.rank.RankDoc; import org.elasticsearch.search.rank.RankDoc.RankKey; -import org.elasticsearch.search.retriever.RankDocsRetrieverBuilder; +import org.elasticsearch.search.retriever.CombineRetrieverBuilder; import org.elasticsearch.search.retriever.RetrieverBuilder; import org.elasticsearch.search.retriever.RetrieverParserContext; -import org.elasticsearch.search.sort.FieldSortBuilder; -import org.elasticsearch.search.sort.ScoreSortBuilder; -import org.elasticsearch.search.sort.ShardDocSortField; -import org.elasticsearch.search.sort.SortBuilder; import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentBuilder; @@ -52,7 +35,11 @@ import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; -public class RRFRetrieverBuilder extends RetrieverBuilder { +/** + * A {@link CombineRetrieverBuilder} that combines top documents from the sub-retrievers using + * Reciprocal Rank Fusion (RRF). + */ +public final class RRFRetrieverBuilder extends CombineRetrieverBuilder { public static final String NAME = "rrf"; public static final NodeFeature RRF_RETRIEVER_SUPPORTED = new NodeFeature("rrf_retriever_supported"); @@ -65,9 +52,13 @@ public class RRFRetrieverBuilder extends RetrieverBuilder { NAME, false, args -> { + List childRetrievers = (List) args[0]; + List childRetrieverSources = childRetrievers.stream() + .map(r -> new RetrieverSource(r, null)) + .collect(Collectors.toList()); int rankWindowSize = args[1] == null ? RRFRankBuilder.DEFAULT_RANK_WINDOW_SIZE : (int) args[1]; int rankConstant = args[2] == null ? RRFRankBuilder.DEFAULT_RANK_CONSTANT : (int) args[2]; - return new RRFRetrieverBuilder((List) args[0], rankWindowSize, rankConstant); + return new RRFRetrieverBuilder(childRetrieverSources, rankWindowSize, rankConstant); } ); @@ -94,46 +85,15 @@ public static RRFRetrieverBuilder fromXContent(XContentParser parser, RetrieverP return PARSER.apply(parser, context); } - private record RetrieverSource(RetrieverBuilder retriever, SearchSourceBuilder source) {} - - private final List retrievers; - private final int rankWindowSize; private final int rankConstant; - private final SetOnce rankDocsSupplier; - public RRFRetrieverBuilder(List retrieverBuilders, int rankWindowSize, int rankConstant) { - this( - retrieverBuilders.stream().map(r -> new RetrieverSource(r, null)).collect(Collectors.toList()), - rankWindowSize, - rankConstant, - null - ); + public RRFRetrieverBuilder(int rankWindowSize, int rankConstant) { + this(new ArrayList<>(), rankWindowSize, rankConstant); } - private RRFRetrieverBuilder( - List retrievers, - int rankWindowSize, - int rankConstant, - SetOnce rankDocsSupplier - ) { - this.retrievers = retrievers; - this.rankWindowSize = rankWindowSize; + private RRFRetrieverBuilder(List childRetrievers, int rankWindowSize, int rankConstant) { + super(childRetrievers, rankWindowSize); this.rankConstant = rankConstant; - this.rankDocsSupplier = rankDocsSupplier; - } - - private RRFRetrieverBuilder( - RRFRetrieverBuilder clone, - List preFilterQueryBuilders, - List retrievers, - SetOnce rankDocsSupplier - ) { - super(clone); - this.preFilterQueryBuilders = preFilterQueryBuilders; - this.rankWindowSize = clone.rankWindowSize; - this.rankConstant = clone.rankConstant; - this.retrievers = retrievers; - this.rankDocsSupplier = rankDocsSupplier; } @Override @@ -142,161 +102,11 @@ public String getName() { } @Override - public boolean isCompound() { - return true; - } - - @Override - public RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOException { - if (ctx.pointInTimeBuilder() == null) { - throw new IllegalStateException("PIT is required"); - } - - if (rankDocsSupplier != null) { - return this; - } - - // Rewrite prefilters - boolean hasChanged = false; - var newPreFilters = rewritePreFilters(ctx); - hasChanged |= newPreFilters != preFilterQueryBuilders; - - // Rewrite retriever sources - List newRetrievers = new ArrayList<>(); - for (var entry : retrievers) { - RetrieverBuilder newRetriever = entry.retriever.rewrite(ctx); - if (newRetriever != entry.retriever) { - newRetrievers.add(new RetrieverSource(newRetriever, null)); - hasChanged |= newRetriever != entry.retriever; - } else if (newRetriever == entry.retriever) { - var sourceBuilder = entry.source != null ? entry.source : createSearchSourceBuilder(ctx.pointInTimeBuilder(), newRetriever); - var rewrittenSource = sourceBuilder.rewrite(ctx); - newRetrievers.add(new RetrieverSource(newRetriever, rewrittenSource)); - hasChanged |= rewrittenSource != entry.source; - } - } - if (hasChanged) { - return new RRFRetrieverBuilder(this, newPreFilters, newRetrievers, null); - } - - // execute searches - final SetOnce results = new SetOnce<>(); - final MultiSearchRequest multiSearchRequest = new MultiSearchRequest(); - for (var entry : retrievers) { - SearchRequest searchRequest = new SearchRequest().source(entry.source); - // The can match phase can reorder shards, so we disable it to ensure the stable ordering - searchRequest.setPreFilterShardSize(Integer.MAX_VALUE); - multiSearchRequest.add(searchRequest); - } - ctx.registerAsyncAction((client, listener) -> { - client.execute(TransportMultiSearchAction.TYPE, multiSearchRequest, new ActionListener<>() { - @Override - public void onResponse(MultiSearchResponse items) { - List topDocs = new ArrayList<>(); - for (int i = 0; i < items.getResponses().length; i++) { - var item = items.getResponses()[i]; - topDocs.add(getTopDocs(item.getResponse())); - } - results.set(combineQueryPhaseResults(topDocs)); - listener.onResponse(null); - } - - @Override - public void onFailure(Exception e) { - listener.onFailure(e); - } - }); - }); - - return new RankDocsRetrieverBuilder( - rankWindowSize, - newRetrievers.stream().map(s -> s.retriever).toList(), - results::get, - newPreFilters - ); - } - - @Override - public QueryBuilder topDocsQuery(QueryBuilder leadQuery) { - throw new IllegalStateException(NAME + " cannot be nested"); - } - - @Override - public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder, boolean compoundUsed) { - throw new IllegalStateException("Should not be called, missing a rewrite?"); - } - - // ---- FOR TESTING XCONTENT PARSING ---- - @Override - public void doToXContent(XContentBuilder builder, Params params) throws IOException { - if (retrievers.isEmpty() == false) { - builder.startArray(RETRIEVERS_FIELD.getPreferredName()); - - for (var entry : retrievers) { - builder.startObject(); - builder.field(entry.retriever.getName()); - entry.retriever.toXContent(builder, params); - builder.endObject(); - } - builder.endArray(); - } - - builder.field(RANK_WINDOW_SIZE_FIELD.getPreferredName(), rankWindowSize); - builder.field(RANK_CONSTANT_FIELD.getPreferredName(), rankConstant); + public RRFRetrieverBuilder clone(RRFRetrieverBuilder original, List newRetrievers) { + return new RRFRetrieverBuilder(newRetrievers, original.windowSize, original.rankConstant); } @Override - public boolean doEquals(Object o) { - RRFRetrieverBuilder that = (RRFRetrieverBuilder) o; - return rankWindowSize == that.rankWindowSize && rankConstant == that.rankConstant && Objects.equals(retrievers, that.retrievers); - } - - @Override - public int doHashCode() { - return Objects.hash(retrievers, rankWindowSize, rankConstant); - } - - private SearchSourceBuilder createSearchSourceBuilder(PointInTimeBuilder pit, RetrieverBuilder retrieverBuilder) { - var sourceBuilder = new SearchSourceBuilder().pointInTimeBuilder(pit) - .trackTotalHits(false) - .storedFields(new StoredFieldsContext(false)) - .size(rankWindowSize); - retrieverBuilder.extractToSearchSourceBuilder(sourceBuilder, false); - - // apply the pre-filters - if (preFilterQueryBuilders.size() > 0) { - QueryBuilder query = sourceBuilder.query(); - BoolQueryBuilder newQuery = new BoolQueryBuilder(); - if (query != null) { - newQuery.must(query); - } - preFilterQueryBuilders.stream().forEach(newQuery::filter); - sourceBuilder.query(newQuery); - } - - // Record the shard id in the sort result - List> sortBuilders = sourceBuilder.sorts() != null ? new ArrayList<>(sourceBuilder.sorts()) : new ArrayList<>(); - if (sortBuilders.isEmpty()) { - sortBuilders.add(new ScoreSortBuilder()); - } - sortBuilders.add(new FieldSortBuilder(FieldSortBuilder.SHARD_DOC_FIELD_NAME)); - sourceBuilder.sort(sortBuilders); - return sourceBuilder; - } - - private ScoreDoc[] getTopDocs(SearchResponse searchResponse) { - int size = Math.min(rankWindowSize, searchResponse.getHits().getHits().length); - ScoreDoc[] docs = new ScoreDoc[size]; - for (int i = 0; i < size; i++) { - var hit = searchResponse.getHits().getAt(i); - long sortValue = (long) hit.getRawSortValues()[hit.getRawSortValues().length - 1]; - int doc = ShardDocSortField.decodeDoc(sortValue); - int shardRequestIndex = ShardDocSortField.decodeShardRequestIndex(sortValue); - docs[i] = new ScoreDoc(doc, hit.getScore(), shardRequestIndex); - } - return docs; - } - public RRFRankDoc[] combineQueryPhaseResults(List rankResults) { // combine the disjointed sets of TopDocs into a single set or RRFRankDocs // each RRFRankDoc will have both the position and score for each query where @@ -304,7 +114,7 @@ public RRFRankDoc[] combineQueryPhaseResults(List rankResults) { // if a doc isn't part of a result set its position will be NO_RANK [0] and // its score is [0f] int queries = rankResults.size(); - Map docsToRankResults = Maps.newMapWithExpectedSize(rankWindowSize); + Map docsToRankResults = Maps.newMapWithExpectedSize(windowSize); int index = 0; for (var rrfRankResult : rankResults) { int rank = 1; @@ -357,7 +167,7 @@ public RRFRankDoc[] combineQueryPhaseResults(List rankResults) { }); // trim the results if needed, otherwise each shard will always return `rank_window_size` results. // pagination and all else will happen on the coordinator when combining the shard responses - RRFRankDoc[] topResults = new RRFRankDoc[Math.min(rankWindowSize, sortedResults.length)]; + RRFRankDoc[] topResults = new RRFRankDoc[Math.min(windowSize, sortedResults.length)]; for (int rank = 0; rank < topResults.length; ++rank) { topResults[rank] = sortedResults[rank]; topResults[rank].rank = rank + 1; @@ -365,4 +175,35 @@ public RRFRankDoc[] combineQueryPhaseResults(List rankResults) { } return topResults; } + + // ---- FOR TESTING XCONTENT PARSING ---- + + @Override + public boolean doEquals(Object o) { + RRFRetrieverBuilder that = (RRFRetrieverBuilder) o; + return super.doEquals(o) && rankConstant == that.rankConstant; + } + + @Override + public int doHashCode() { + return Objects.hash(super.doHashCode(), rankConstant); + } + + @Override + public void doToXContent(XContentBuilder builder, Params params) throws IOException { + if (childRetrievers.isEmpty() == false) { + builder.startArray(RETRIEVERS_FIELD.getPreferredName()); + + for (var entry : childRetrievers) { + builder.startObject(); + builder.field(entry.retriever().getName()); + entry.retriever().toXContent(builder, params); + builder.endObject(); + } + builder.endArray(); + } + + builder.field(RANK_WINDOW_SIZE_FIELD.getPreferredName(), windowSize); + builder.field(RANK_CONSTANT_FIELD.getPreferredName(), rankConstant); + } } diff --git a/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderParsingTests.java b/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderParsingTests.java index c350e3f0d6ed0..e360237371a82 100644 --- a/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderParsingTests.java +++ b/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderParsingTests.java @@ -29,13 +29,6 @@ public class RRFRetrieverBuilderParsingTests extends AbstractXContentTestCase retrieverBuilders = new ArrayList<>(retrieverCount); - - while (retrieverCount > 0) { - retrieverBuilders.add(TestRetrieverBuilder.createRandomTestRetrieverBuilder()); - --retrieverCount; - } int rankWindowSize = RRFRankBuilder.DEFAULT_RANK_WINDOW_SIZE; if (randomBoolean()) { rankWindowSize = randomIntBetween(1, 10000); @@ -44,7 +37,13 @@ public static RRFRetrieverBuilder createRandomRRFRetrieverBuilder() { if (randomBoolean()) { rankConstant = randomIntBetween(1, 1000000); } - return new RRFRetrieverBuilder(retrieverBuilders, rankWindowSize, rankConstant); + var ret = new RRFRetrieverBuilder(rankWindowSize, rankConstant); + int retrieverCount = randomIntBetween(2, 50); + while (retrieverCount > 0) { + ret.addChild(TestRetrieverBuilder.createRandomTestRetrieverBuilder()); + --retrieverCount; + } + return ret; } @Override