Skip to content

Commit

Permalink
Propagating nested inner_hits to the parent compound retriever (elast…
Browse files Browse the repository at this point in the history
  • Loading branch information
pmpailis authored Nov 13, 2024
1 parent 5204902 commit 5b25dee
Show file tree
Hide file tree
Showing 19 changed files with 248 additions and 42 deletions.
6 changes: 6 additions & 0 deletions docs/changelog/116408.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pr: 116408
summary: Propagating nested `inner_hits` to the parent compound retriever
area: Ranking
type: bug
issues:
- 116397
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@
import org.elasticsearch.cluster.health.ClusterHealthStatus;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.index.query.InnerHitBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.sort.NestedSortBuilder;
import org.elasticsearch.search.sort.SortBuilders;
import org.elasticsearch.search.sort.SortMode;
Expand Down Expand Up @@ -1581,6 +1583,64 @@ public void testCheckFixedBitSetCache() throws Exception {
assertThat(clusterStatsResponse.getIndicesStats().getSegments().getBitsetMemoryInBytes(), equalTo(0L));
}

/**
 * Checks that a nested query declaring inner hits returns them by default, and that
 * building the search source with {@code skipInnerHits(true)} leaves them out of the response.
 */
public void testSkipNestedInnerHits() throws Exception {
    assertAcked(prepareCreate("test").setMapping("nested1", "type=nested"));
    ensureGreen();

    // One parent document holding a single nested object.
    var document = jsonBuilder().startObject()
        .field("field1", "value1")
        .startArray("nested1")
        .startObject()
        .field("n_field1", "foo")
        .field("n_field2", "bar")
        .endObject()
        .endArray()
        .endObject();
    prepareIndex("test").setId("1").setSource(document).get();

    waitForRelocation(ClusterHealthStatus.GREEN);
    GetResponse indexedDoc = client().prepareGet("test", "1").get();
    assertThat(indexedDoc.isExists(), equalTo(true));
    assertThat(indexedDoc.getSourceAsBytesRef(), notNullValue());
    refresh();

    // by default we should get inner hits
    SearchSourceBuilder withInnerHits = new SearchSourceBuilder().query(
        QueryBuilders.nestedQuery("nested1", QueryBuilders.termQuery("nested1.n_field1", "foo"), ScoreMode.Avg)
            .innerHit(new InnerHitBuilder())
    );
    assertNoFailuresAndResponse(prepareSearch("test").setSource(withInnerHits), res -> {
        assertNotNull(res.getHits());
        assertHitCount(res, 1);
        var hits = res.getHits().getHits();
        assertThat(hits.length, equalTo(1));
        var innerHits = hits[0].getInnerHits();
        assertNotNull(innerHits);
        assertNotNull(innerHits.get("nested1"));
    });

    // if we explicitly say to ignore inner hits, then this should now be null
    SearchSourceBuilder withoutInnerHits = new SearchSourceBuilder().query(
        QueryBuilders.nestedQuery("nested1", QueryBuilders.termQuery("nested1.n_field1", "foo"), ScoreMode.Avg)
            .innerHit(new InnerHitBuilder())
    ).skipInnerHits(true);
    assertNoFailuresAndResponse(prepareSearch("test").setSource(withoutInnerHits), res -> {
        assertNotNull(res.getHits());
        assertHitCount(res, 1);
        var hits = res.getHits().getHits();
        assertThat(hits.length, equalTo(1));
        assertNull(hits[0].getInnerHits());
    });
}

private void assertDocumentCount(String index, long numdocs) {
IndicesStatsResponse stats = indicesAdmin().prepareStats(index).clear().setDocs(true).get();
assertNoFailures(stats);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ static TransportVersion def(int id) {
public static final TransportVersion DATA_STREAM_INDEX_VERSION_DEPRECATION_CHECK = def(8_788_00_0);
public static final TransportVersion ADD_COMPATIBILITY_VERSIONS_TO_NODE_INFO = def(8_789_00_0);
public static final TransportVersion VERTEX_AI_INPUT_TYPE_ADDED = def(8_790_00_0);

public static final TransportVersion SKIP_INNER_HITS_SEARCH_SOURCE = def(8_791_00_0);
/*
* STOP! READ THIS FIRST! No, really,
* ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.search.retriever.rankdoc;
package org.elasticsearch.index.query;

import org.apache.lucene.index.IndexReader;
import org.apache.lucene.search.Query;
Expand All @@ -16,15 +16,13 @@
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.index.query.AbstractQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.index.query.SearchExecutionContext;
import org.elasticsearch.search.rank.RankDoc;
import org.elasticsearch.search.retriever.rankdoc.RankDocsQuery;
import org.elasticsearch.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.Arrays;
import java.util.Map;
import java.util.Objects;

import static org.elasticsearch.TransportVersions.RRF_QUERY_REWRITE;
Expand Down Expand Up @@ -55,6 +53,15 @@ public RankDocsQueryBuilder(StreamInput in) throws IOException {
}
}

/**
 * Forwards every entry of {@code queryBuilders} to
 * {@link InnerHitContextBuilder#extractInnerHits} so that inner hits declared by those
 * queries are collected into {@code innerHits}. Does nothing when {@code queryBuilders}
 * is {@code null}.
 *
 * @param innerHits map passed through to the extraction call for each query
 */
@Override
protected void extractInnerHitBuilders(Map<String, InnerHitContextBuilder> innerHits) {
    if (queryBuilders == null) {
        return;
    }
    for (QueryBuilder candidate : queryBuilders) {
        InnerHitContextBuilder.extractInnerHits(candidate, innerHits);
    }
}

@Override
protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException {
if (queryBuilders != null) {
Expand All @@ -71,7 +78,7 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws
return super.doRewrite(queryRewriteContext);
}

RankDoc[] rankDocs() {
/**
 * Returns the {@code rankDocs} array held by this builder. The array itself is returned,
 * not a copy.
 */
public RankDoc[] rankDocs() {
return rankDocs;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ private SearchCapabilities() {}
private static final String KQL_QUERY_SUPPORTED = "kql_query";
/** Support multi-dense-vector field mapper. */
private static final String MULTI_DENSE_VECTOR_FIELD_MAPPER = "multi_dense_vector_field_mapper";
/** Support propagating nested retrievers' inner_hits to top-level compound retrievers. */
private static final String NESTED_RETRIEVER_INNER_HITS_SUPPORT = "nested_retriever_inner_hits_support";

public static final Set<String> CAPABILITIES;
static {
Expand All @@ -45,6 +47,7 @@ private SearchCapabilities() {}
capabilities.add(BYTE_FLOAT_BIT_DOT_PRODUCT_CAPABILITY);
capabilities.add(DENSE_VECTOR_DOCVALUE_FIELDS);
capabilities.add(TRANSFORM_RANK_RRF_TO_RETRIEVER);
capabilities.add(NESTED_RETRIEVER_INNER_HITS_SUPPORT);
if (MultiDenseVectorFieldMapper.FEATURE_FLAG.isEnabled()) {
capabilities.add(MULTI_DENSE_VECTOR_FIELD_MAPPER);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryStringQueryBuilder;
import org.elasticsearch.index.query.RangeQueryBuilder;
import org.elasticsearch.index.query.RankDocsQueryBuilder;
import org.elasticsearch.index.query.RegexpQueryBuilder;
import org.elasticsearch.index.query.ScriptQueryBuilder;
import org.elasticsearch.index.query.SimpleQueryStringBuilder;
Expand Down Expand Up @@ -238,7 +239,6 @@
import org.elasticsearch.search.retriever.RetrieverBuilder;
import org.elasticsearch.search.retriever.RetrieverParserContext;
import org.elasticsearch.search.retriever.StandardRetrieverBuilder;
import org.elasticsearch.search.retriever.rankdoc.RankDocsQueryBuilder;
import org.elasticsearch.search.sort.FieldSortBuilder;
import org.elasticsearch.search.sort.GeoDistanceSortBuilder;
import org.elasticsearch.search.sort.ScoreSortBuilder;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1285,13 +1285,17 @@ private void parseSource(DefaultSearchContext context, SearchSourceBuilder sourc
);
if (query != null) {
QueryBuilder rewrittenForInnerHits = Rewriteable.rewrite(query, innerHitsRewriteContext, true);
InnerHitContextBuilder.extractInnerHits(rewrittenForInnerHits, innerHitBuilders);
if (false == source.skipInnerHits()) {
InnerHitContextBuilder.extractInnerHits(rewrittenForInnerHits, innerHitBuilders);
}
searchExecutionContext.setAliasFilter(context.request().getAliasFilter().getQueryBuilder());
context.parsedQuery(searchExecutionContext.toQuery(query));
}
if (source.postFilter() != null) {
QueryBuilder rewrittenForInnerHits = Rewriteable.rewrite(source.postFilter(), innerHitsRewriteContext, true);
InnerHitContextBuilder.extractInnerHits(rewrittenForInnerHits, innerHitBuilders);
if (false == source.skipInnerHits()) {
InnerHitContextBuilder.extractInnerHits(rewrittenForInnerHits, innerHitBuilders);
}
context.parsedPostFilter(searchExecutionContext.toQuery(source.postFilter()));
}
if (innerHitBuilders.size() > 0) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,8 @@ public static HighlightBuilder highlight() {

private Map<String, Object> runtimeMappings = emptyMap();

private boolean skipInnerHits = false;

/**
* Constructs a new search source builder.
*/
Expand Down Expand Up @@ -290,6 +292,11 @@ public SearchSourceBuilder(StreamInput in) throws IOException {
if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_8_0)) {
rankBuilder = in.readOptionalNamedWriteable(RankBuilder.class);
}
if (in.getTransportVersion().onOrAfter(TransportVersions.SKIP_INNER_HITS_SEARCH_SOURCE)) {
skipInnerHits = in.readBoolean();
} else {
skipInnerHits = false;
}
}

@Override
Expand Down Expand Up @@ -379,6 +386,9 @@ public void writeTo(StreamOutput out) throws IOException {
} else if (rankBuilder != null) {
throw new IllegalArgumentException("cannot serialize [rank] to version [" + out.getTransportVersion().toReleaseVersion() + "]");
}
if (out.getTransportVersion().onOrAfter(TransportVersions.SKIP_INNER_HITS_SEARCH_SOURCE)) {
out.writeBoolean(skipInnerHits);
}
}

/**
Expand Down Expand Up @@ -1280,6 +1290,7 @@ private SearchSourceBuilder shallowCopy(
rewrittenBuilder.collapse = collapse;
rewrittenBuilder.pointInTimeBuilder = pointInTimeBuilder;
rewrittenBuilder.runtimeMappings = runtimeMappings;
rewrittenBuilder.skipInnerHits = skipInnerHits;
return rewrittenBuilder;
}

Expand Down Expand Up @@ -1838,6 +1849,9 @@ public XContentBuilder innerToXContent(XContentBuilder builder, Params params) t
if (false == runtimeMappings.isEmpty()) {
builder.field(RUNTIME_MAPPINGS_FIELD.getPreferredName(), runtimeMappings);
}
if (skipInnerHits) {
builder.field("skipInnerHits", true);
}

return builder;
}
Expand All @@ -1850,6 +1864,15 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
return builder;
}

/**
 * Sets the flag that controls whether inner hits are skipped for this search source.
 * The flag is {@code false} unless set explicitly.
 *
 * @param skipInnerHits {@code true} to skip inner hits, {@code false} to keep them
 * @return this builder, so calls can be chained
 */
public SearchSourceBuilder skipInnerHits(boolean skipInnerHits) {
this.skipInnerHits = skipInnerHits;
return this;
}

/**
 * Returns the current value of the skip-inner-hits flag.
 *
 * @return {@code true} if inner hits are to be skipped for this search source
 */
public boolean skipInnerHits() {
return this.skipInnerHits;
}

public static class IndexBoost implements Writeable, ToXContentObject {
private final String index;
private final float boost;
Expand Down Expand Up @@ -2104,7 +2127,8 @@ public int hashCode() {
collapse,
trackTotalHitsUpTo,
pointInTimeBuilder,
runtimeMappings
runtimeMappings,
skipInnerHits
);
}

Expand Down Expand Up @@ -2149,7 +2173,8 @@ public boolean equals(Object obj) {
&& Objects.equals(collapse, other.collapse)
&& Objects.equals(trackTotalHitsUpTo, other.trackTotalHitsUpTo)
&& Objects.equals(pointInTimeBuilder, other.pointInTimeBuilder)
&& Objects.equals(runtimeMappings, other.runtimeMappings);
&& Objects.equals(runtimeMappings, other.runtimeMappings)
&& Objects.equals(skipInnerHits, other.skipInnerHits);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ public int doHashCode() {
return Objects.hash(innerRetrievers);
}

protected SearchSourceBuilder createSearchSourceBuilder(PointInTimeBuilder pit, RetrieverBuilder retrieverBuilder) {
protected final SearchSourceBuilder createSearchSourceBuilder(PointInTimeBuilder pit, RetrieverBuilder retrieverBuilder) {
var sourceBuilder = new SearchSourceBuilder().pointInTimeBuilder(pit)
.trackTotalHits(false)
.storedFields(new StoredFieldsContext(false))
Expand All @@ -254,6 +254,11 @@ protected SearchSourceBuilder createSearchSourceBuilder(PointInTimeBuilder pit,
}
sortBuilders.add(new FieldSortBuilder(FieldSortBuilder.SHARD_DOC_FIELD_NAME));
sourceBuilder.sort(sortBuilders);
sourceBuilder.skipInnerHits(true);
return finalizeSourceBuilder(sourceBuilder);
}

/**
 * Hook called as the last step of {@code createSearchSourceBuilder}; whatever it returns
 * becomes that method's result. This implementation returns the builder unchanged.
 *
 * @param sourceBuilder the source builder assembled by {@code createSearchSourceBuilder}
 * @return the source builder to use
 */
protected SearchSourceBuilder finalizeSourceBuilder(SearchSourceBuilder sourceBuilder) {
return sourceBuilder;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.index.query.RankDocsQueryBuilder;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.retriever.rankdoc.RankDocsQueryBuilder;
import org.elasticsearch.search.vectors.ExactKnnQueryBuilder;
import org.elasticsearch.search.vectors.KnnSearchBuilder;
import org.elasticsearch.search.vectors.QueryVectorBuilder;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.index.query.RankDocsQueryBuilder;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.rank.RankDoc;
import org.elasticsearch.search.retriever.rankdoc.RankDocsQueryBuilder;
import org.elasticsearch.xcontent.XContentBuilder;

import java.io.IOException;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ private static int[] findSegmentStarts(IndexReader reader, RankDoc[] docs) {
return starts;
}

RankDoc[] rankDocs() {
/**
 * Returns the {@code docs} array held by this query. The array itself is returned,
 * not a copy.
 */
public RankDoc[] rankDocs() {
return docs;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.search.retriever.rankdoc;
package org.elasticsearch.index.query;

import org.apache.lucene.document.Document;
import org.apache.lucene.document.NumericDocValuesField;
Expand All @@ -22,9 +22,8 @@
import org.apache.lucene.search.TopScoreDocCollectorManager;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.RandomIndexWriter;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.SearchExecutionContext;
import org.elasticsearch.search.rank.RankDoc;
import org.elasticsearch.search.retriever.rankdoc.RankDocsQuery;
import org.elasticsearch.test.AbstractQueryTestCase;

import java.io.IOException;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.RankDocsQueryBuilder;
import org.elasticsearch.search.SearchModule;
import org.elasticsearch.search.retriever.rankdoc.RankDocsQueryBuilder;
import org.elasticsearch.test.AbstractWireSerializingTestCase;

import java.io.IOException;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.index.query.RandomQueryBuilder;
import org.elasticsearch.index.query.RankDocsQueryBuilder;
import org.elasticsearch.index.query.Rewriteable;
import org.elasticsearch.search.SearchModule;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.rank.RankDoc;
import org.elasticsearch.search.retriever.rankdoc.RankDocsQueryBuilder;
import org.elasticsearch.test.AbstractXContentTestCase;
import org.elasticsearch.usage.SearchUsage;
import org.elasticsearch.xcontent.NamedXContentRegistry;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.index.query.RandomQueryBuilder;
import org.elasticsearch.index.query.RankDocsQueryBuilder;
import org.elasticsearch.index.query.Rewriteable;
import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.rank.RankDoc;
import org.elasticsearch.search.retriever.rankdoc.RankDocsQueryBuilder;
import org.elasticsearch.test.ESTestCase;

import java.io.IOException;
Expand Down
Loading

0 comments on commit 5b25dee

Please sign in to comment.