From f6a97edb82e38c0cc9004cfe46da6cf8e600ebc1 Mon Sep 17 00:00:00 2001 From: Jay Deng Date: Fri, 15 Sep 2023 17:06:05 -0700 Subject: [PATCH] Fix concurrent search NPE with track_total_hits, size=0, and terminate_after --- .../search/simple/SimpleSearchIT.java | 39 +++++++++++++++++++ .../search/query/TopDocsCollectorContext.java | 14 +++++-- 2 files changed, 50 insertions(+), 3 deletions(-) diff --git a/server/src/internalClusterTest/java/org/opensearch/search/simple/SimpleSearchIT.java b/server/src/internalClusterTest/java/org/opensearch/search/simple/SimpleSearchIT.java index 0e6073ad11689..6e48360fb3a02 100644 --- a/server/src/internalClusterTest/java/org/opensearch/search/simple/SimpleSearchIT.java +++ b/server/src/internalClusterTest/java/org/opensearch/search/simple/SimpleSearchIT.java @@ -292,6 +292,45 @@ public void testSimpleTerminateAfterCount() throws Exception { assertFalse(searchResponse.isTerminatedEarly()); } + public void testSimpleTerminateAfterCountWithSizeAndTrackHits() throws Exception { + prepareCreate("test").setSettings(Settings.builder().put(SETTING_NUMBER_OF_SHARDS, 1).put(SETTING_NUMBER_OF_REPLICAS, 0)).get(); + ensureGreen(); + int max = randomIntBetween(3, 29); + List docbuilders = new ArrayList<>(max); + + for (int i = 1; i <= max; i++) { + String id = String.valueOf(i); + docbuilders.add(client().prepareIndex("test").setId(id).setSource("field", i)); + } + + indexRandom(true, docbuilders); + ensureGreen(); + refresh(); + + SearchResponse searchResponse; + for (int i = 1; i < max; i++) { + searchResponse = client().prepareSearch("test") + .setQuery(QueryBuilders.matchAllQuery()) + .setTerminateAfter(i) + .setSize(0) + .setTrackTotalHits(true) + .get(); + assertEquals(0, searchResponse.getFailedShards()); + assertHitCount(searchResponse, i); + assertTrue(searchResponse.isTerminatedEarly()); + + searchResponse = client().prepareSearch("test") + .setQuery(QueryBuilders.matchAllQuery()) + .setTerminateAfter(i) + .setSize(randomIntBetween(1, max)) + .setTrackTotalHits(true) + .get(); + assertEquals(0, searchResponse.getFailedShards()); + assertHitCount(searchResponse, i); + assertTrue(searchResponse.isTerminatedEarly()); + } + } + public void testSimpleIndexSortEarlyTerminate() throws Exception { prepareCreate("test").setSettings( Settings.builder().put(SETTING_NUMBER_OF_SHARDS, 1).put(SETTING_NUMBER_OF_REPLICAS, 0).put("index.sort.field", "rank") diff --git a/server/src/main/java/org/opensearch/search/query/TopDocsCollectorContext.java b/server/src/main/java/org/opensearch/search/query/TopDocsCollectorContext.java index 39c34f7c0d5d5..8a6e600cb5130 100644 --- a/server/src/main/java/org/opensearch/search/query/TopDocsCollectorContext.java +++ b/server/src/main/java/org/opensearch/search/query/TopDocsCollectorContext.java @@ -136,9 +136,10 @@ private EmptyTopDocsCollectorContext( Query query, @Nullable SortAndFormats sortAndFormats, int trackTotalHitsUpTo, - boolean hasFilterCollector + boolean hasFilterCollector, + int numDocs ) throws IOException { - super(REASON_SEARCH_COUNT, 0); + super(REASON_SEARCH_COUNT, numDocs); this.sort = sortAndFormats == null ? null : sortAndFormats.sort; this.trackTotalHitsUpTo = trackTotalHitsUpTo; if (this.trackTotalHitsUpTo == SearchContext.TRACK_TOTAL_HITS_DISABLED) { @@ -189,6 +190,12 @@ CollectorManager createManager(CollectorManager( + new TotalHitCountCollectorManager.Empty(new TotalHits(numHits, TotalHits.Relation.EQUAL_TO), sort), + numHits, + false + ); } } else { manager = new EarlyTerminatingCollectorManager<>( @@ -778,7 +785,8 @@ public static TopDocsCollectorContext createTopDocsCollectorContext(SearchContex query, searchContext.sort(), searchContext.trackTotalHitsUpTo(), - hasFilterCollector + hasFilterCollector, + totalNumDocs ); } else if (searchContext.scrollContext() != null) { // we can disable the tracking of total hits after the initial scroll query