Skip to content

Commit

Permalink
iter
Browse files Browse the repository at this point in the history
  • Loading branch information
pmpailis committed Nov 7, 2024
1 parent 4dc5afc commit 3cf0ad5
Show file tree
Hide file tree
Showing 18 changed files with 124 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.search.retriever.rankdoc;
package org.elasticsearch.index.query;

import org.apache.lucene.index.IndexReader;
import org.apache.lucene.search.Query;
Expand All @@ -16,15 +16,13 @@
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.index.query.AbstractQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.index.query.SearchExecutionContext;
import org.elasticsearch.search.rank.RankDoc;
import org.elasticsearch.search.retriever.rankdoc.RankDocsQuery;
import org.elasticsearch.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.Arrays;
import java.util.Map;
import java.util.Objects;

import static org.elasticsearch.TransportVersions.RRF_QUERY_REWRITE;
Expand Down Expand Up @@ -55,6 +53,17 @@ public RankDocsQueryBuilder(StreamInput in) throws IOException {
}
}

@Override
protected void extractInnerHitBuilders(Map<String, InnerHitContextBuilder> innerHits) {
if (queryBuilders != null) {
for (QueryBuilder query : queryBuilders) {
if (query instanceof AbstractQueryBuilder) {
((AbstractQueryBuilder<?>) query).extractInnerHitBuilders(innerHits);
}
}
}
}

@Override
protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException {
if (queryBuilders != null) {
Expand All @@ -71,7 +80,7 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws
return super.doRewrite(queryRewriteContext);
}

RankDoc[] rankDocs() {
public RankDoc[] rankDocs() {
return rankDocs;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryStringQueryBuilder;
import org.elasticsearch.index.query.RangeQueryBuilder;
import org.elasticsearch.index.query.RankDocsQueryBuilder;
import org.elasticsearch.index.query.RegexpQueryBuilder;
import org.elasticsearch.index.query.ScriptQueryBuilder;
import org.elasticsearch.index.query.SimpleQueryStringBuilder;
Expand Down Expand Up @@ -238,7 +239,6 @@
import org.elasticsearch.search.retriever.RetrieverBuilder;
import org.elasticsearch.search.retriever.RetrieverParserContext;
import org.elasticsearch.search.retriever.StandardRetrieverBuilder;
import org.elasticsearch.search.retriever.rankdoc.RankDocsQueryBuilder;
import org.elasticsearch.search.sort.FieldSortBuilder;
import org.elasticsearch.search.sort.GeoDistanceSortBuilder;
import org.elasticsearch.search.sort.ScoreSortBuilder;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,8 @@ public static HighlightBuilder highlight() {

private Map<String, Object> runtimeMappings = emptyMap();

private boolean innerHitsDisabled = false;

/**
* Constructs a new search source builder.
*/
Expand Down Expand Up @@ -1838,6 +1840,14 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
return builder;
}

public void innerHitsDisabled(boolean innerHitsDisabled) {
this.innerHitsDisabled = innerHitsDisabled;
}

public boolean innerHitsDisabled() {
return this.innerHitsDisabled;
}

public static class IndexBoost implements Writeable, ToXContentObject {
private final String index;
private final float boost;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,10 @@ public SourceLoader sourceLoader() {
return sourceLoader;
}

public boolean innerHitsDisabled() {
return searchContext.innerHitsDisabled();
}

/**
* For a hit document that's being processed, return the source lookup representing the
* root document. This method is used to pass down the root source when processing this
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public InnerHitsPhase(FetchPhase fetchPhase) {

@Override
public FetchSubPhaseProcessor getProcessor(FetchContext searchContext) {
if (searchContext.innerHits() == null || searchContext.innerHits().getInnerHits().isEmpty()) {
if (searchContext.innerHitsDisabled() || searchContext.innerHits() == null || searchContext.innerHits().getInnerHits().isEmpty()) {
return null;
}
Map<String, InnerHitsContext.InnerHitSubContext> innerHits = searchContext.innerHits().getInnerHits();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -402,4 +402,8 @@ public String toString() {
public abstract SourceLoader newSourceLoader();

public abstract IdLoader newIdLoader();

public boolean innerHitsDisabled() {
return request() != null && request().innerHitsDisabled();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ public class ShardSearchRequest extends TransportRequest implements IndicesReque
private Boolean requestCache;
private final long nowInMillis;
private final boolean allowPartialSearchResults;
public final boolean innerHitsDisabled;
private final OriginalIndices originalIndices;

private boolean canReturnNullResponseIfMatchNoDocs;
Expand Down Expand Up @@ -250,6 +251,7 @@ public ShardSearchRequest(
this.waitForCheckpoint = waitForCheckpoint;
this.waitForCheckpointsTimeout = waitForCheckpointsTimeout;
this.forceSyntheticSource = forceSyntheticSource;
this.innerHitsDisabled = source.innerHitsDisabled();
}

@SuppressWarnings("this-escape")
Expand All @@ -275,6 +277,7 @@ public ShardSearchRequest(ShardSearchRequest clone) {
this.waitForCheckpoint = clone.waitForCheckpoint;
this.waitForCheckpointsTimeout = clone.waitForCheckpointsTimeout;
this.forceSyntheticSource = clone.forceSyntheticSource;
this.innerHitsDisabled = clone.innerHitsDisabled;
}

public ShardSearchRequest(StreamInput in) throws IOException {
Expand Down Expand Up @@ -341,6 +344,11 @@ public ShardSearchRequest(StreamInput in) throws IOException {
*/
forceSyntheticSource = false;
}
if (in.getTransportVersion().onOrAfter(TransportVersion.current())) {
innerHitsDisabled = in.readBoolean();
} else {
innerHitsDisabled = false;
}
originalIndices = OriginalIndices.readOriginalIndices(in);
}

Expand Down Expand Up @@ -401,6 +409,9 @@ protected final void innerWriteTo(StreamOutput out, boolean asKey) throws IOExce
throw new IllegalArgumentException("force_synthetic_source is not supported before 8.4.0");
}
}
if (out.getTransportVersion().onOrAfter(TransportVersion.current())) {
out.writeBoolean(innerHitsDisabled);
}
}

@Override
Expand Down Expand Up @@ -482,6 +493,10 @@ public boolean allowPartialSearchResults() {
return allowPartialSearchResults;
}

public boolean innerHitsDisabled() {
return innerHitsDisabled;
}

public Scroll scroll() {
return scroll;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,7 @@ protected SearchSourceBuilder createSearchSourceBuilder(PointInTimeBuilder pit,
}
sortBuilders.add(new FieldSortBuilder(FieldSortBuilder.SHARD_DOC_FIELD_NAME));
sourceBuilder.sort(sortBuilders);
sourceBuilder.innerHitsDisabled(true);
return sourceBuilder;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.index.query.RankDocsQueryBuilder;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.retriever.rankdoc.RankDocsQueryBuilder;
import org.elasticsearch.search.vectors.ExactKnnQueryBuilder;
import org.elasticsearch.search.vectors.KnnSearchBuilder;
import org.elasticsearch.search.vectors.QueryVectorBuilder;
Expand Down Expand Up @@ -240,6 +240,9 @@ public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder
List<KnnSearchBuilder> knnSearchBuilders = new ArrayList<>(searchSourceBuilder.knnSearch());
knnSearchBuilders.add(knnSearchBuilder);
searchSourceBuilder.knnSearch(knnSearchBuilders);
if (compoundUsed) {
searchSourceBuilder.innerHitsDisabled(true);
}
}

// ---- FOR TESTING XCONTENT PARSING ----
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.index.query.RankDocsQueryBuilder;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.rank.RankDoc;
import org.elasticsearch.search.retriever.rankdoc.RankDocsQueryBuilder;
import org.elasticsearch.xcontent.XContentBuilder;

import java.io.IOException;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,9 @@ public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder
if (collapseBuilder != null) {
searchSourceBuilder.collapse(collapseBuilder);
}
if (compoundUsed) {
searchSourceBuilder.innerHitsDisabled(true);
}
}

// ---- FOR TESTING XCONTENT PARSING ----
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ private static int[] findSegmentStarts(IndexReader reader, RankDoc[] docs) {
return starts;
}

RankDoc[] rankDocs() {
public RankDoc[] rankDocs() {
return docs;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.search.retriever.rankdoc;
package org.elasticsearch.index.query;

import org.apache.lucene.document.Document;
import org.apache.lucene.document.NumericDocValuesField;
Expand All @@ -22,9 +22,8 @@
import org.apache.lucene.search.TopScoreDocCollectorManager;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.RandomIndexWriter;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.SearchExecutionContext;
import org.elasticsearch.search.rank.RankDoc;
import org.elasticsearch.search.retriever.rankdoc.RankDocsQuery;
import org.elasticsearch.test.AbstractQueryTestCase;

import java.io.IOException;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.RankDocsQueryBuilder;
import org.elasticsearch.search.SearchModule;
import org.elasticsearch.search.retriever.rankdoc.RankDocsQueryBuilder;
import org.elasticsearch.test.AbstractWireSerializingTestCase;

import java.io.IOException;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.index.query.RandomQueryBuilder;
import org.elasticsearch.index.query.RankDocsQueryBuilder;
import org.elasticsearch.index.query.Rewriteable;
import org.elasticsearch.search.SearchModule;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.rank.RankDoc;
import org.elasticsearch.search.retriever.rankdoc.RankDocsQueryBuilder;
import org.elasticsearch.test.AbstractXContentTestCase;
import org.elasticsearch.usage.SearchUsage;
import org.elasticsearch.xcontent.NamedXContentRegistry;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.index.query.RandomQueryBuilder;
import org.elasticsearch.index.query.RankDocsQueryBuilder;
import org.elasticsearch.index.query.Rewriteable;
import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.rank.RankDoc;
import org.elasticsearch.search.retriever.rankdoc.RankDocsQueryBuilder;
import org.elasticsearch.test.ESTestCase;

import java.io.IOException;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.elasticsearch.common.ParsingException;
import org.elasticsearch.features.NodeFeature;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.RankDocsQueryBuilder;
import org.elasticsearch.license.LicenseUtils;
import org.elasticsearch.search.builder.PointInTimeBuilder;
import org.elasticsearch.search.builder.SearchSourceBuilder;
Expand All @@ -19,7 +20,6 @@
import org.elasticsearch.search.retriever.RetrieverBuilder;
import org.elasticsearch.search.retriever.RetrieverBuilderWrapper;
import org.elasticsearch.search.retriever.RetrieverParserContext;
import org.elasticsearch.search.retriever.rankdoc.RankDocsQueryBuilder;
import org.elasticsearch.search.sort.ScoreSortBuilder;
import org.elasticsearch.search.sort.SortBuilder;
import org.elasticsearch.xcontent.ConstructingObjectParser;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@ setup:
properties:
views:
type: long
nested_inner_hits:
type: nested
properties:
data:
type: keyword

- do:
index:
Expand Down Expand Up @@ -125,6 +130,7 @@ setup:
integer: 2
keyword: "technology"
nested: { views: 10}
nested_inner_hits: [{"data": "foo"}, {"data": "bar"}, {"data": "baz"}]
- do:
indices.refresh: {}

Expand Down Expand Up @@ -960,3 +966,55 @@ setup:
- length: { hits.hits : 1 }

- match: { hits.hits.0._id: "1" }

---
"rrf retriever with inner_hits for sub-retriever":

- do:
search:
_source: false
index: test
body:
retriever:
rrf:
retrievers: [
{
standard: {
query: {
nested: {
path: nested_inner_hits,
inner_hits: {
_source: false,
"sort": [ {
"nested_inner_hits.data": "asc"
}
],
fields: [ nested_inner_hits.data ]
},
query: {
match_all: { }
}
}
}
}
},
{
standard: {
query: {
match_all: { }
}
}
}
]
rank_window_size: 10
rank_constant: 10
size: 2

- match: { hits.total.value: 9 }

- match: { hits.hits.0.inner_hits.nested_inner_hits.hits.total.value: 3 }
- match: { hits.hits.0.inner_hits.nested_inner_hits.hits.hits.0.fields.nested_inner_hits.0.data.0: bar }
- match: { hits.hits.0.inner_hits.nested_inner_hits.hits.hits.1.fields.nested_inner_hits.0.data.0: baz }
- match: { hits.hits.0.inner_hits.nested_inner_hits.hits.hits.2.fields.nested_inner_hits.0.data.0: foo }

- match: { hits.hits.1.inner_hits.nested_inner_hits.hits.total.value: 0 }

0 comments on commit 3cf0ad5

Please sign in to comment.