diff --git a/server/src/main/java/org/elasticsearch/search/retriever/RetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/RetrieverBuilder.java index 882d44adb79c3..fa8d2ee1cbefc 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/RetrieverBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/RetrieverBuilder.java @@ -218,6 +218,10 @@ public void setRankDocs(RankDoc[] rankDocs) { this.rankDocs = rankDocs; } + public RankDoc[] getRankDocs() { + return rankDocs; + } + /** * Gets the filters for this retriever. */ diff --git a/server/src/main/java/org/elasticsearch/search/retriever/RetrieverBuilderWrapper.java b/server/src/main/java/org/elasticsearch/search/retriever/RetrieverBuilderWrapper.java index aff60e9ffdf95..7c9e71dbf41f1 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/RetrieverBuilderWrapper.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/RetrieverBuilderWrapper.java @@ -74,6 +74,11 @@ public void setRankDocs(RankDoc[] rankDocs) { in.setRankDocs(rankDocs); } + @Override + public RankDoc[] getRankDocs() { + return in.getRankDocs(); + } + @Override public boolean isCompound() { return in.isCompound(); diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/rules/retriever/QueryRuleRetrieverBuilder.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/rules/retriever/QueryRuleRetrieverBuilder.java index 07d9f199b80fe..730eb33421083 100644 --- a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/rules/retriever/QueryRuleRetrieverBuilder.java +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/rules/retriever/QueryRuleRetrieverBuilder.java @@ -133,7 +133,7 @@ private static void checkValidSort(List> sortBuilders) { return; } - if (sortBuilders.size() > 1 || sortBuilders.get(0) instanceof ScoreSortBuilder == false) { + if (sortBuilders.get(0) instanceof ScoreSortBuilder == false) { throw new IllegalArgumentException("Rule retrievers can only sort documents by relevance score, got: " + sortBuilders); } } @@ -159,6 +159,7 @@ protected RankDoc[] combineInnerRetrieverResults(List rankResults) { for (int i = 0; i < scoreDocs.length; i++) { ScoreDoc scoreDoc = scoreDocs[i]; rankDocs[i] = new RankDoc(scoreDoc.doc, scoreDoc.score, scoreDoc.shardIndex); + rankDocs[i].rank = i+1; } return rankDocs; } @@ -175,13 +176,13 @@ public int doHashCode() { } class QueryRuleRetrieverBuilderWrapper extends RetrieverBuilderWrapper { - protected QueryRuleRetrieverBuilderWrapper(RetrieverBuilder sub) { - super(sub); + protected QueryRuleRetrieverBuilderWrapper(RetrieverBuilder in) { + super(in); } @Override - protected QueryRuleRetrieverBuilderWrapper clone(RetrieverBuilder sub) { - return new QueryRuleRetrieverBuilderWrapper(sub); + protected QueryRuleRetrieverBuilderWrapper clone(RetrieverBuilder in) { + return new QueryRuleRetrieverBuilderWrapper(in); } @Override @@ -192,7 +193,7 @@ public QueryBuilder topDocsQuery() { @Override public QueryBuilder explainQuery() { return new RankDocsQueryBuilder( - rankDocs, + in.getRankDocs(), new QueryBuilder[] { new RuleQueryBuilder(in.explainQuery(), matchCriteria, rulesetIds) }, true );