Skip to content

Commit

Permalink
more fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
jimczi committed Jun 10, 2024
1 parent 3e725d1 commit 9cf29f8
Show file tree
Hide file tree
Showing 6 changed files with 51 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ void executeRequest(
frozenIndexCheck(resolvedIndices);
}

var retriever = searchRequest.source().retriever();
var retriever = searchRequest.source().consumeRetriever();
ActionListener<SearchRequest> rewriteSearchRequestListener = listener.delegateFailureAndWrap((delegate, rewritten) -> {
if (ccsCheckCompatibility) {
checkCCSVersionCompatibility(rewritten);
Expand Down Expand Up @@ -476,8 +476,9 @@ void executeRequest(
);
return;
}
searchRequest.setPreFilterShardSize(Integer.MAX_VALUE);
if (retriever.requiresPointInTime() && searchRequest.source().pointInTimeBuilder() == null) {
// The can-match phase can reorder shards, so we disable it to ensure a stable ordering
searchRequest.setPreFilterShardSize(Integer.MAX_VALUE);
rewriteSearchRequestListener = ActionListener.releaseAfter(
rewriteSearchRequestListener,
() -> closePIT(searchRequest.source().pointInTimeBuilder())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -369,8 +369,10 @@ public void writeTo(StreamOutput out) throws IOException {
}
}

public RetrieverBuilder retriever() {
return retrieverBuilder;
public RetrieverBuilder consumeRetriever() {
var ret = retrieverBuilder;
this.retrieverBuilder = null;
return ret;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ public class StoredFieldsContext implements Writeable {
private final List<String> fieldNames;
private final boolean fetchFields;

private StoredFieldsContext(boolean fetchFields) {
public StoredFieldsContext(boolean fetchFields) {
this.fetchFields = fetchFields;
this.fieldNames = null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,12 @@ public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder
@Override
protected boolean doEquals(Object o) {
RankDocsRetrieverBuilder other = (RankDocsRetrieverBuilder) o;
return Arrays.equals(rankDocs.get(), other.rankDocs.get());
return Arrays.equals(rankDocs.get(), other.rankDocs.get()) && sources.equals(other.sources);
}

@Override
protected int doHashCode() {
return Objects.hash(super.hashCode(), windowSize, rankDocs.get());
return Objects.hash(super.hashCode(), windowSize);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ public ShardDocSortField(int shardRequestIndex, boolean reverse) {
this.shardRequestIndex = shardRequestIndex;
}

/**
 * Returns the upper 32 bits of {@code value}, narrowed to an {@code int}.
 * NOTE(review): presumably {@code value} is a packed shard-doc sort value whose high half carries the
 * shard request index; confirm against the code that produces these values.
 *
 * @param value the packed 64-bit value
 * @return bits 32..63 of {@code value} as an {@code int}
 */
public static int shardRequestIndex(long value) {
    final long highHalf = value >> 32;
    return (int) highHalf;
}

/** Returns the shard request index stored in this sort field. */
int getShardRequestIndex() {
    return this.shardRequestIndex;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,17 @@
import org.elasticsearch.common.util.Maps;
import org.elasticsearch.features.NodeFeature;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.search.builder.PointInTimeBuilder;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.fetch.StoredFieldsContext;
import org.elasticsearch.search.rank.RankDoc;
import org.elasticsearch.search.rank.RankDoc.RankKey;
import org.elasticsearch.search.retriever.RankDocsRetrieverBuilder;
import org.elasticsearch.search.retriever.RetrieverBuilder;
import org.elasticsearch.search.retriever.RetrieverParserContext;
import org.elasticsearch.search.sort.FieldSortBuilder;
import org.elasticsearch.search.sort.ScoreSortBuilder;
import org.elasticsearch.search.sort.ShardDocSortField;
import org.elasticsearch.search.sort.SortBuilder;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
Expand Down Expand Up @@ -123,29 +128,16 @@ public RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOException {
}
List<RetrieverSource> newRetrievers = new ArrayList<>();
boolean hasChanged = false;
for (var source : retrievers) {
RetrieverBuilder rewritten = source.retriever.rewrite(ctx);
if (rewritten != source.retriever) {
newRetrievers.add(new RetrieverSource(rewritten, null));
hasChanged |= rewritten != source.retriever;
} else if (rewritten == source.retriever) {
SearchSourceBuilder sourceBuilder;
if (source.source == null) {
sourceBuilder = new SearchSourceBuilder().pointInTimeBuilder(ctx.pointInTimeBuilder()).size(rankWindowSize);
rewritten.extractToSearchSourceBuilder(sourceBuilder, false);
List<SortBuilder<?>> sortBuilders = sourceBuilder.sorts() != null
? new ArrayList<>(sourceBuilder.sorts())
: new ArrayList<>();
if (sortBuilders.isEmpty()) {
sortBuilders.add(new ScoreSortBuilder());
}
sourceBuilder.sort(sortBuilders);
} else {
sourceBuilder = source.source;
}
for (var entry : retrievers) {
RetrieverBuilder newRetriever = entry.retriever.rewrite(ctx);
if (newRetriever != entry.retriever) {
newRetrievers.add(new RetrieverSource(newRetriever, null));
hasChanged |= newRetriever != entry.retriever;
} else if (newRetriever == entry.retriever) {
var sourceBuilder = entry.source != null ? entry.source : createSearchSourceBuilder(ctx.pointInTimeBuilder(), newRetriever);
var rewrittenSource = sourceBuilder.rewrite(ctx);
newRetrievers.add(new RetrieverSource(rewritten, rewrittenSource));
hasChanged |= rewrittenSource != source.source;
newRetrievers.add(new RetrieverSource(newRetriever, rewrittenSource));
hasChanged |= rewrittenSource != entry.source;
}
}
if (hasChanged) {
Expand All @@ -157,8 +149,9 @@ public RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOException {
}

MultiSearchRequest multiSearchRequest = new MultiSearchRequest();
for (var ret : retrievers) {
SearchRequest searchRequest = new SearchRequest().source(ret.source);
for (var entry : retrievers) {
SearchRequest searchRequest = new SearchRequest().source(entry.source);
// The can-match phase can reorder shards, so we disable it to ensure a stable ordering
searchRequest.setPreFilterShardSize(Integer.MAX_VALUE);
multiSearchRequest.add(searchRequest);
}
Expand Down Expand Up @@ -221,14 +214,29 @@ public int doHashCode() {
return Objects.hash(retrievers, rankWindowSize, rankConstant);
}

/**
 * Builds the search source used to execute one child retriever.
 *
 * The source is bound to the given point in time, does not track total hits, is configured with a
 * {@code StoredFieldsContext} whose {@code fetchFields} flag is {@code false}, and is limited to
 * {@code rankWindowSize} hits. The retriever then contributes its own settings via
 * {@code extractToSearchSourceBuilder}. If no sort is defined afterwards, a score sort is used, and a
 * shard-doc sort is always appended as the last sort criterion.
 *
 * @param pit the point in time the search should run against
 * @param retrieverBuilder the retriever whose settings are extracted into the source
 * @return the configured search source
 */
private SearchSourceBuilder createSearchSourceBuilder(PointInTimeBuilder pit, RetrieverBuilder retrieverBuilder) {
    final SearchSourceBuilder source = new SearchSourceBuilder().pointInTimeBuilder(pit)
        .trackTotalHits(false)
        .storedFields(new StoredFieldsContext(false))
        .size(rankWindowSize);
    retrieverBuilder.extractToSearchSourceBuilder(source, false);

    // Start from a mutable copy of whatever sorts the retriever defined, if any.
    final List<SortBuilder<?>> sorts = new ArrayList<>();
    if (source.sorts() != null) {
        sorts.addAll(source.sorts());
    }
    // Fall back to sorting by score when the retriever did not specify a sort.
    if (sorts.isEmpty()) {
        sorts.add(new ScoreSortBuilder());
    }
    // The shard-doc sort goes last so its value is the final entry of each hit's sort values.
    sorts.add(new FieldSortBuilder(FieldSortBuilder.SHARD_DOC_FIELD_NAME));
    source.sort(sorts);
    return source;
}

/**
 * Converts the hits of a search response into an array of ScoreDoc entries.
 *
 * At most {@code rankWindowSize} hits are converted. For each hit, the last raw sort value is read as a
 * {@code long} and its upper 32 bits are used as the shard index of the resulting ScoreDoc.
 *
 * @param searchResponse the response whose hits are converted
 * @return one ScoreDoc per converted hit, in the order the hits appear in the response
 */
private ScoreDoc[] getTopDocs(SearchResponse searchResponse) {
// Never read past the number of hits actually returned.
int size = Math.min(rankWindowSize, searchResponse.getHits().getHits().length);
ScoreDoc[] docs = new ScoreDoc[size];
for (int i = 0; i < size; i++) {
var hit = searchResponse.getHits().getAt(i);
// NOTE(review): the next three lines and the two after them compute and store the same value twice;
// this looks like the pre-change and post-change versions of one statement shown together — confirm
// which pair is meant to remain.
long sortValue = (long) hit.getRawSortValues()[hit.getRawSortValues().length - 1];
int shardIndex = (int) (sortValue >> 32);
docs[i] = new ScoreDoc(hit.docId(), hit.getScore(), shardIndex);
// Same extraction, delegated to ShardDocSortField.shardRequestIndex.
int shardRequestIndex = ShardDocSortField.shardRequestIndex((long) hit.getRawSortValues()[hit.getRawSortValues().length - 1]);
docs[i] = new ScoreDoc(hit.docId(), hit.getScore(), shardRequestIndex);
}
return docs;
}
Expand All @@ -240,15 +248,14 @@ public RRFRankDoc[] combineQueryPhaseResults(List<ScoreDoc[]> rankResults) {
// if a doc isn't part of a result set its position will be NO_RANK [0] and
// its score is [0f]
int queries = rankResults.size();
Map<Long, RRFRankDoc> docsToRankResults = Maps.newMapWithExpectedSize(rankWindowSize);
Map<RankDoc.RankKey, RRFRankDoc> docsToRankResults = Maps.newMapWithExpectedSize(rankWindowSize);
int index = 0;
for (var rrfRankResult : rankResults) {
int rank = 1;
for (ScoreDoc scoreDoc : rrfRankResult) {
final int findex = index;
final int frank = rank;
long docAndShard = (((long) scoreDoc.shardIndex) << 32) | (scoreDoc.doc & 0xFFFFFFFFL);
docsToRankResults.compute(docAndShard, (key, value) -> {
docsToRankResults.compute(new RankKey(scoreDoc.doc, scoreDoc.shardIndex), (key, value) -> {
if (value == null) {
value = new RRFRankDoc(scoreDoc.doc, scoreDoc.shardIndex, queries);
}
Expand Down

0 comments on commit 9cf29f8

Please sign in to comment.