Skip to content

Commit

Permalink
Address concurrency issue in top hits aggregation (#106990)
Browse files Browse the repository at this point in the history
Top hits aggregation runs the fetch phase concurrently when the query phase is executed across multiple slices. This is problematic as the fetch phase does not support concurrent execution yet.

The core of the issue is that the search execution context is shared across slices; each slice concurrently calls setLookupProviders on it, each time installing different instances of the preloaded source and field lookup providers. As a result, state leaks across slices, and we hit lucene assertions that ensure that stored fields loaded from a certain thread are not read from a different thread.

We have not hit this before because the problem revolves around SearchLookup, which is used by runtime fields. TopHitsIT is the main test we have for the top hits aggregation, but it uses a mock script engine that bypasses painless and SearchLookup.
  • Loading branch information
javanna authored Apr 4, 2024
1 parent a32512f commit d6582cf
Show file tree
Hide file tree
Showing 9 changed files with 184 additions and 115 deletions.
5 changes: 5 additions & 0 deletions docs/changelog/106990.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Release-notes changelog entry for PR 106990; picked up by the docs build.
pr: 106990
summary: Address concurrency issue in top hits aggregation
area: Aggregations
type: bug
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,22 @@
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.aggregations.InternalMultiBucketAggregation;
import org.elasticsearch.search.aggregations.bucket.MultiBucketsAggregation;
import org.elasticsearch.search.aggregations.bucket.global.GlobalAggregationBuilder;
import org.elasticsearch.search.aggregations.bucket.global.InternalGlobal;
import org.elasticsearch.search.aggregations.metrics.InternalTopHits;
import org.elasticsearch.search.aggregations.metrics.TopHitsAggregationBuilder;
import org.elasticsearch.test.ESSingleNodeTestCase;
import org.elasticsearch.xcontent.XContentType;
import org.hamcrest.Matchers;

import static org.elasticsearch.action.support.WriteRequest.RefreshPolicy.IMMEDIATE;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertNoFailures;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertNoFailuresAndResponse;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThan;

/**
* Test that indexes enough data to trigger the creation of Cuckoo filters.
Expand Down Expand Up @@ -64,4 +74,33 @@ private void assertNumRareTerms(int maxDocs, int rareTerms) {
}
);
}

/**
 * Verifies that hits produced by a {@code top_hits} aggregation nested under a
 * {@code global} aggregation (via two levels of rare-terms aggregations) carry
 * a real relevance score, i.e. strictly greater than zero.
 */
public void testGlobalAggregationWithScore() {
    createIndex("global", Settings.EMPTY, "_doc", "keyword", "type=keyword");
    // Index three distinct keyword values, refreshing each write so they are searchable.
    for (String keywordValue : new String[] { "a", "c", "e" }) {
        prepareIndex("global").setSource("keyword", keywordValue).setRefreshPolicy(IMMEDIATE).get();
    }
    GlobalAggregationBuilder aggregation = new GlobalAggregationBuilder("global").subAggregation(
        new RareTermsAggregationBuilder("terms").field("keyword")
            .subAggregation(
                new RareTermsAggregationBuilder("sub_terms").field("keyword")
                    .subAggregation(new TopHitsAggregationBuilder("top_hits").storedField("_none_"))
            )
    );
    assertNoFailuresAndResponse(client().prepareSearch("global").addAggregation(aggregation), response -> {
        InternalGlobal globalResult = response.getAggregations().get("global");
        InternalMultiBucketAggregation<?, ?> termsResult = globalResult.getAggregations().get("terms");
        assertThat(termsResult.getBuckets().size(), equalTo(3));
        for (MultiBucketsAggregation.Bucket termsBucket : termsResult.getBuckets()) {
            InternalMultiBucketAggregation<?, ?> subTermsResult = termsBucket.getAggregations().get("sub_terms");
            assertThat(subTermsResult.getBuckets().size(), equalTo(1));
            MultiBucketsAggregation.Bucket onlyBucket = subTermsResult.getBuckets().get(0);
            InternalTopHits topHits = onlyBucket.getAggregations().get("top_hits");
            assertThat(topHits.getHits().getHits().length, equalTo(1));
            // Every returned hit must have been scored, not left at 0 / NaN.
            for (SearchHit hit : topHits.getHits()) {
                assertThat(hit.getScore(), greaterThan(0f));
            }
        }
    });
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,24 @@
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.script.Script;
import org.elasticsearch.script.ScriptType;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.aggregations.AggregationExecutionException;
import org.elasticsearch.search.aggregations.AggregationTestScriptsPlugin;
import org.elasticsearch.search.aggregations.Aggregator;
import org.elasticsearch.search.aggregations.Aggregator.SubAggCollectionMode;
import org.elasticsearch.search.aggregations.BucketOrder;
import org.elasticsearch.search.aggregations.InternalMultiBucketAggregation;
import org.elasticsearch.search.aggregations.bucket.AbstractTermsTestCase;
import org.elasticsearch.search.aggregations.bucket.MultiBucketsAggregation;
import org.elasticsearch.search.aggregations.bucket.filter.Filter;
import org.elasticsearch.search.aggregations.bucket.global.GlobalAggregationBuilder;
import org.elasticsearch.search.aggregations.bucket.global.InternalGlobal;
import org.elasticsearch.search.aggregations.metrics.Avg;
import org.elasticsearch.search.aggregations.metrics.ExtendedStats;
import org.elasticsearch.search.aggregations.metrics.InternalTopHits;
import org.elasticsearch.search.aggregations.metrics.Stats;
import org.elasticsearch.search.aggregations.metrics.Sum;
import org.elasticsearch.search.aggregations.metrics.TopHitsAggregationBuilder;
import org.elasticsearch.search.aggregations.support.ValueType;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.test.ESIntegTestCase;
Expand Down Expand Up @@ -63,6 +71,7 @@
import static org.elasticsearch.xcontent.XContentFactory.jsonBuilder;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.startsWith;
import static org.hamcrest.core.IsNull.notNullValue;
Expand Down Expand Up @@ -1376,4 +1385,46 @@ private void assertOrderByKeyResponse(
}
);
}

/**
 * Verifies that documents fetched by a {@code top_hits} aggregation nested
 * under a {@code global} aggregation (through two levels of terms
 * aggregations) keep a positive relevance score, across randomly chosen
 * execution hints and sub-aggregation collection modes.
 */
public void testGlobalAggregationWithScore() throws Exception {
    assertAcked(prepareCreate("global").setMapping("keyword", "type=keyword"));
    indexRandom(
        true,
        prepareIndex("global").setSource("keyword", "a"),
        prepareIndex("global").setSource("keyword", "c"),
        prepareIndex("global").setSource("keyword", "e")
    );
    // Randomize execution strategy so all code paths get coverage over time.
    String hint = randomFrom(TermsAggregatorFactory.ExecutionMode.values()).toString();
    Aggregator.SubAggCollectionMode mode = randomFrom(Aggregator.SubAggCollectionMode.values());
    GlobalAggregationBuilder aggregation = new GlobalAggregationBuilder("global").subAggregation(
        new TermsAggregationBuilder("terms").userValueTypeHint(ValueType.STRING)
            .executionHint(hint)
            .collectMode(mode)
            .field("keyword")
            .order(BucketOrder.key(true))
            .subAggregation(
                new TermsAggregationBuilder("sub_terms").userValueTypeHint(ValueType.STRING)
                    .executionHint(hint)
                    .collectMode(mode)
                    .field("keyword")
                    .order(BucketOrder.key(true))
                    .subAggregation(new TopHitsAggregationBuilder("top_hits").storedField("_none_"))
            )
    );
    assertNoFailuresAndResponse(prepareSearch("global").addAggregation(aggregation), response -> {
        InternalGlobal globalResult = response.getAggregations().get("global");
        InternalMultiBucketAggregation<?, ?> termsResult = globalResult.getAggregations().get("terms");
        assertThat(termsResult.getBuckets().size(), equalTo(3));
        for (MultiBucketsAggregation.Bucket termsBucket : termsResult.getBuckets()) {
            InternalMultiBucketAggregation<?, ?> subTermsResult = termsBucket.getAggregations().get("sub_terms");
            assertThat(subTermsResult.getBuckets().size(), equalTo(1));
            MultiBucketsAggregation.Bucket onlyBucket = subTermsResult.getBuckets().get(0);
            InternalTopHits topHits = onlyBucket.getAggregations().get("top_hits");
            assertThat(topHits.getHits().getHits().length, equalTo(1));
            // Scores must survive the fetch phase of the top_hits aggregation.
            for (SearchHit hit : topHits.getHits()) {
                assertThat(hit.getScore(), greaterThan(0f));
            }
        }
    });
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
*/
package org.elasticsearch.search.aggregations.metrics;

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.join.ScoreMode;
import org.apache.lucene.util.ArrayUtil;
Expand All @@ -20,6 +21,7 @@
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.index.seqno.SequenceNumbers;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.plugins.SearchPlugin;
import org.elasticsearch.script.MockScriptEngine;
import org.elasticsearch.script.MockScriptPlugin;
import org.elasticsearch.script.Script;
Expand All @@ -34,15 +36,21 @@
import org.elasticsearch.search.aggregations.bucket.nested.Nested;
import org.elasticsearch.search.aggregations.bucket.terms.Terms;
import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregatorFactory.ExecutionMode;
import org.elasticsearch.search.fetch.FetchSubPhase;
import org.elasticsearch.search.fetch.FetchSubPhaseProcessor;
import org.elasticsearch.search.fetch.StoredFieldsSpec;
import org.elasticsearch.search.fetch.subphase.highlight.HighlightBuilder;
import org.elasticsearch.search.fetch.subphase.highlight.HighlightField;
import org.elasticsearch.search.lookup.FieldLookup;
import org.elasticsearch.search.lookup.LeafSearchLookup;
import org.elasticsearch.search.rescore.QueryRescorerBuilder;
import org.elasticsearch.search.sort.ScriptSortBuilder.ScriptSortType;
import org.elasticsearch.search.sort.SortBuilders;
import org.elasticsearch.search.sort.SortOrder;
import org.elasticsearch.test.ESIntegTestCase;
import org.elasticsearch.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
Expand Down Expand Up @@ -87,7 +95,7 @@ public class TopHitsIT extends ESIntegTestCase {

@Override
protected Collection<Class<? extends Plugin>> nodePlugins() {
return Collections.singleton(CustomScriptPlugin.class);
return List.of(CustomScriptPlugin.class, FetchPlugin.class);
}

public static class CustomScriptPlugin extends MockScriptPlugin {
Expand All @@ -110,7 +118,7 @@ public static String randomExecutionHint() {

@Override
public void setupSuiteScopeCluster() throws Exception {
assertAcked(prepareCreate("idx").setMapping(TERMS_AGGS_FIELD, "type=keyword"));
assertAcked(prepareCreate("idx").setMapping(TERMS_AGGS_FIELD, "type=keyword", "text", "type=text,store=true"));
assertAcked(prepareCreate("field-collapsing").setMapping("group", "type=keyword"));
createIndex("empty");
assertAcked(
Expand Down Expand Up @@ -592,7 +600,7 @@ public void testFieldCollapsing() throws Exception {
);
}

public void testFetchFeatures() {
public void testFetchFeatures() throws IOException {
final boolean seqNoAndTerm = randomBoolean();
assertNoFailuresAndResponse(
prepareSearch("idx").setQuery(matchQuery("text", "text").queryName("test"))
Expand Down Expand Up @@ -642,19 +650,14 @@ public void testFetchFeatures() {

assertThat(hit.getMatchedQueries()[0], equalTo("test"));

DocumentField field1 = hit.field("field1");
assertThat(field1.getValue(), equalTo(5L));

DocumentField field2 = hit.field("field2");
assertThat(field2.getValue(), equalTo(2.71f));

assertThat(hit.getSourceAsMap().get("text").toString(), equalTo("some text to entertain"));

field2 = hit.field("script");
assertThat(field2.getValue().toString(), equalTo("5"));
assertThat(hit.field("field1").getValue(), equalTo(5L));
assertThat(hit.field("field2").getValue(), equalTo(2.71f));
assertThat(hit.field("script").getValue().toString(), equalTo("5"));

assertThat(hit.getSourceAsMap().size(), equalTo(1));
assertThat(hit.getSourceAsMap().get("text").toString(), equalTo("some text to entertain"));
assertEquals("some text to entertain", hit.getFields().get("text").getValue());
assertEquals("some text to entertain", hit.getFields().get("text_stored_lookup").getValue());
}
}
);
Expand Down Expand Up @@ -1263,4 +1266,37 @@ public void testWithRescore() {
}
);
}

/**
 * Test plugin registering a fetch sub-phase that resolves the stored
 * {@code text} field through the {@link LeafSearchLookup} and re-exposes it as
 * the {@code text_stored_lookup} document field. Per the motivation of this
 * change, this routes the fetch phase through SearchLookup, which the mock
 * script engine used by this test would otherwise bypass.
 */
public static class FetchPlugin extends Plugin implements SearchPlugin {
    @Override
    public List<FetchSubPhase> getFetchSubPhases(FetchPhaseConstructionContext context) {
        return Collections.singletonList(fetchContext -> {
            // Only instrument the "idx" index; a null processor means "skip".
            if (fetchContext.getIndexName().equals("idx") == false) {
                return null;
            }
            return new FetchSubPhaseProcessor() {

                private LeafSearchLookup lookup;

                @Override
                public void setNextReader(LeafReaderContext ctx) {
                    // Re-resolve the lookup for each segment we move to.
                    lookup = fetchContext.getSearchExecutionContext().lookup().getLeafSearchLookup(ctx);
                }

                @Override
                public void process(FetchSubPhase.HitContext hitContext) {
                    lookup.setDocument(hitContext.docId());
                    FieldLookup textField = lookup.fields().get("text");
                    hitContext.hit()
                        .setDocumentField("text_stored_lookup", new DocumentField("text_stored_lookup", textField.getValues()));
                }

                @Override
                public StoredFieldsSpec storedFieldsSpec() {
                    // This sub-phase declares no stored-field needs of its own.
                    return StoredFieldsSpec.NO_REQUIREMENTS;
                }
            };
        });
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.elasticsearch.common.util.LongObjectPagedHashMap;
import org.elasticsearch.common.util.LongObjectPagedHashMap.Cursor;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.index.query.SearchExecutionContext;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHits;
import org.elasticsearch.search.aggregations.AggregationExecutionContext;
Expand Down Expand Up @@ -191,8 +192,7 @@ public InternalAggregation buildAggregation(long owningBucketOrdinal) throws IOE
for (int i = 0; i < topDocs.scoreDocs.length; i++) {
docIdsToLoad[i] = topDocs.scoreDocs[i].doc;
}
subSearchContext.fetchPhase().execute(subSearchContext, docIdsToLoad);
FetchSearchResult fetchResult = subSearchContext.fetchResult();
FetchSearchResult fetchResult = runFetchPhase(subSearchContext, docIdsToLoad);
if (fetchProfiles != null) {
fetchProfiles.add(fetchResult.profileResult());
}
Expand All @@ -216,6 +216,19 @@ public InternalAggregation buildAggregation(long owningBucketOrdinal) throws IOE
);
}

/**
 * Runs the fetch phase for the given doc ids against a forked copy of the
 * provided context.
 * <p>
 * The {@link SearchExecutionContext} is forked per invocation because the
 * fetch phase does not support concurrent execution yet, while top hits may
 * trigger fetches from multiple slices at the same time: each slice calls
 * {@code setLookupProviders} on the context, so sharing one instance would
 * cross state between slices.
 *
 * @param subSearchContext the top-hits sub context to fetch against
 * @param docIdsToLoad     the lucene doc ids to load
 * @return the fetch result produced by the forked context
 */
private static FetchSearchResult runFetchPhase(SubSearchContext subSearchContext, int[] docIdsToLoad) {
    final SearchExecutionContext forkedExecutionContext = new SearchExecutionContext(subSearchContext.getSearchExecutionContext());
    SubSearchContext forkedSubSearchContext = new SubSearchContext(subSearchContext) {
        @Override
        public SearchExecutionContext getSearchExecutionContext() {
            // Hand out the per-slice fork instead of the shared context.
            return forkedExecutionContext;
        }
    };
    forkedSubSearchContext.fetchPhase().execute(forkedSubSearchContext, docIdsToLoad);
    return forkedSubSearchContext.fetchResult();
}

@Override
public InternalTopHits buildEmptyAggregation() {
TopDocs topDocs;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,11 @@ private SearchHits buildSearchHits(SearchContext context, int[] docIdsToLoad, Pr

PreloadedSourceProvider sourceProvider = new PreloadedSourceProvider();
PreloadedFieldLookupProvider fieldLookupProvider = new PreloadedFieldLookupProvider();
// The following relies on the fact that we fetch sequentially one segment after another, from a single thread
// This needs to be revised once we add concurrency to the fetch phase, and needs a work-around for situations
// where we run fetch as part of the query phase, where inter-segment concurrency is leveraged.
// One problem is the global setLookupProviders call against the shared execution context.
// Another problem is that the above provider implementations are not thread-safe
context.getSearchExecutionContext().setLookupProviders(sourceProvider, ctx -> fieldLookupProvider);

List<FetchSubPhaseProcessor> processors = getProcessors(context.shardTarget(), fetchContext, profiler);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
public class SubSearchContext extends FilteredSearchContext {

// By default return 3 hits per bucket. A higher default would make the response really large by default, since
// the to hits are returned per bucket.
// the top hits are returned per bucket.
private static final int DEFAULT_SIZE = 3;

private int from;
Expand Down Expand Up @@ -62,6 +62,25 @@ public SubSearchContext(SearchContext context) {
this.querySearchResult = new QuerySearchResult();
}

/**
 * Copy constructor used to fork a {@code SubSearchContext}, e.g. so the fetch
 * phase can run against an overridden search execution context without
 * mutating the original. Copies the fetch-related state field by field; note
 * that the delegated {@code SubSearchContext(SearchContext)} constructor
 * creates a fresh {@code QuerySearchResult} rather than sharing the
 * original's. This is a shallow copy: the referenced objects themselves are
 * not cloned.
 */
public SubSearchContext(SubSearchContext subSearchContext) {
    // Delegate for the base wiring, then copy this class's own fields.
    this((SearchContext) subSearchContext);
    this.from = subSearchContext.from;
    this.size = subSearchContext.size;
    this.sort = subSearchContext.sort;
    this.parsedQuery = subSearchContext.parsedQuery;
    this.query = subSearchContext.query;
    this.storedFields = subSearchContext.storedFields;
    this.scriptFields = subSearchContext.scriptFields;
    this.fetchSourceContext = subSearchContext.fetchSourceContext;
    this.docValuesContext = subSearchContext.docValuesContext;
    this.fetchFieldsContext = subSearchContext.fetchFieldsContext;
    this.highlight = subSearchContext.highlight;
    this.explain = subSearchContext.explain;
    this.trackScores = subSearchContext.trackScores;
    this.version = subSearchContext.version;
    this.seqNoAndPrimaryTerm = subSearchContext.seqNoAndPrimaryTerm;
}

@Override
public void preProcess() {}

Expand Down
Loading

0 comments on commit d6582cf

Please sign in to comment.