diff --git a/server/src/internalClusterTest/java/org/opensearch/search/simple/ParameterizedSimpleSearchIT.java b/server/src/internalClusterTest/java/org/opensearch/search/simple/ParameterizedSimpleSearchIT.java index 719b75079da92..eefa6928a0e2a 100644 --- a/server/src/internalClusterTest/java/org/opensearch/search/simple/ParameterizedSimpleSearchIT.java +++ b/server/src/internalClusterTest/java/org/opensearch/search/simple/ParameterizedSimpleSearchIT.java @@ -308,6 +308,40 @@ public void testSimpleTerminateAfterCountWithSizeAndTrackHits() throws Exception assertEquals(0, searchResponse.getFailedShards()); } + public void testSimpleTerminateAfterCount() throws Exception { + prepareCreate("test").setSettings(Settings.builder().put(SETTING_NUMBER_OF_SHARDS, 1).put(SETTING_NUMBER_OF_REPLICAS, 0)).get(); + ensureGreen(); + int max = randomIntBetween(3, 29); + List docbuilders = new ArrayList<>(max); + + for (int i = 1; i <= max; i++) { + String id = String.valueOf(i); + docbuilders.add(client().prepareIndex("test").setId(id).setSource("field", i)); + } + + indexRandom(true, docbuilders); + ensureGreen(); + refresh(); + + SearchResponse searchResponse; + for (int i = 1; i < max; i++) { + searchResponse = client().prepareSearch("test") + .setQuery(QueryBuilders.rangeQuery("field").gte(1).lte(max)) + .setTerminateAfter(i) + .get(); + assertHitCount(searchResponse, i); + assertTrue(searchResponse.isTerminatedEarly()); + } + + searchResponse = client().prepareSearch("test") + .setQuery(QueryBuilders.rangeQuery("field").gte(1).lte(max)) + .setTerminateAfter(2 * max) + .get(); + + assertHitCount(searchResponse, max); + assertFalse(searchResponse.isTerminatedEarly()); + } + public void testSimpleIndexSortEarlyTerminate() throws Exception { prepareCreate("test").setSettings( Settings.builder().put(SETTING_NUMBER_OF_SHARDS, 1).put(SETTING_NUMBER_OF_REPLICAS, 0).put("index.sort.field", "rank") diff --git a/server/src/internalClusterTest/java/org/opensearch/search/simple/SimpleSearchIT.java b/server/src/internalClusterTest/java/org/opensearch/search/simple/SimpleSearchIT.java deleted file mode 100644 index 67e460653245e..0000000000000 --- a/server/src/internalClusterTest/java/org/opensearch/search/simple/SimpleSearchIT.java +++ /dev/null @@ -1,60 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - */ - -package org.opensearch.search.simple; - -import org.opensearch.action.index.IndexRequestBuilder; -import org.opensearch.action.search.SearchResponse; -import org.opensearch.common.settings.Settings; -import org.opensearch.index.query.QueryBuilders; -import org.opensearch.test.OpenSearchIntegTestCase; - -import java.util.ArrayList; -import java.util.List; - -import static org.opensearch.cluster.metadata.IndexMetadata.SETTING_NUMBER_OF_REPLICAS; -import static org.opensearch.cluster.metadata.IndexMetadata.SETTING_NUMBER_OF_SHARDS; -import static org.opensearch.test.hamcrest.OpenSearchAssertions.assertHitCount; - -public class SimpleSearchIT extends OpenSearchIntegTestCase { - - // TODO: Move this test to ParameterizedSimpleSearchIT after https://github.com/opensearch-project/OpenSearch/issues/8371 - public void testSimpleTerminateAfterCount() throws Exception { - prepareCreate("test").setSettings(Settings.builder().put(SETTING_NUMBER_OF_SHARDS, 1).put(SETTING_NUMBER_OF_REPLICAS, 0)).get(); - ensureGreen(); - int max = randomIntBetween(3, 29); - List docbuilders = new ArrayList<>(max); - - for (int i = 1; i <= max; i++) { - String id = String.valueOf(i); - docbuilders.add(client().prepareIndex("test").setId(id).setSource("field", i)); - } - - indexRandom(true, docbuilders); - ensureGreen(); - refresh(); - - SearchResponse searchResponse; - for (int i = 1; i < max; i++) { - searchResponse = client().prepareSearch("test") - .setQuery(QueryBuilders.rangeQuery("field").gte(1).lte(max)) - .setTerminateAfter(i) - .get(); - assertHitCount(searchResponse, i); - assertTrue(searchResponse.isTerminatedEarly()); - } - - searchResponse = client().prepareSearch("test") - .setQuery(QueryBuilders.rangeQuery("field").gte(1).lte(max)) - .setTerminateAfter(2 * max) - .get(); - - assertHitCount(searchResponse, max); - assertFalse(searchResponse.isTerminatedEarly()); - } -} diff --git a/server/src/main/java/org/opensearch/search/internal/ContextIndexSearcher.java b/server/src/main/java/org/opensearch/search/internal/ContextIndexSearcher.java index aa86ed4e56801..77ea948f7a1c6 100644 --- a/server/src/main/java/org/opensearch/search/internal/ContextIndexSearcher.java +++ b/server/src/main/java/org/opensearch/search/internal/ContextIndexSearcher.java @@ -74,6 +74,7 @@ import org.opensearch.search.profile.query.ProfileWeight; import org.opensearch.search.profile.query.QueryProfiler; import org.opensearch.search.profile.query.QueryTimingType; +import org.opensearch.search.query.EarlyTerminatingCollector; import org.opensearch.search.query.QueryPhase; import org.opensearch.search.query.QuerySearchResult; import org.opensearch.search.sort.FieldSortBuilder; @@ -292,7 +293,7 @@ protected void search(List leaves, Weight weight, Collector c private void searchLeaf(LeafReaderContext ctx, Weight weight, Collector collector) throws IOException { // Check if at all we need to call this leaf for collecting results. - if (canMatch(ctx) == false) { + if (canMatch(ctx) == false || searchContext.isTerminatedEarly()) { return; } @@ -310,6 +311,9 @@ private void searchLeaf(LeafReaderContext ctx, Weight weight, Collector collecto // there is no doc of interest in this reader context // continue with the following leaf return; + } catch (EarlyTerminatingCollector.EarlyTerminationException e) { + searchContext.setTerminatedEarly(true); + return; } catch (QueryPhase.TimeExceededException e) { searchContext.setSearchTimedOut(true); return; @@ -325,6 +329,9 @@ private void searchLeaf(LeafReaderContext ctx, Weight weight, Collector collecto } catch (CollectionTerminatedException e) { // collection was terminated prematurely // continue with the following leaf + } catch (EarlyTerminatingCollector.EarlyTerminationException e) { + searchContext.setTerminatedEarly(true); + return; } catch (QueryPhase.TimeExceededException e) { searchContext.setSearchTimedOut(true); return; @@ -344,6 +351,9 @@ private void searchLeaf(LeafReaderContext ctx, Weight weight, Collector collecto } catch (CollectionTerminatedException e) { // collection was terminated prematurely // continue with the following leaf + } catch (EarlyTerminatingCollector.EarlyTerminationException e) { + searchContext.setTerminatedEarly(true); + return; } catch (QueryPhase.TimeExceededException e) { searchContext.setSearchTimedOut(true); return; diff --git a/server/src/main/java/org/opensearch/search/internal/SearchContext.java b/server/src/main/java/org/opensearch/search/internal/SearchContext.java index dce6da897a74b..523234e35ae24 100644 --- a/server/src/main/java/org/opensearch/search/internal/SearchContext.java +++ b/server/src/main/java/org/opensearch/search/internal/SearchContext.java @@ -119,6 +119,7 @@ public List toAggregators(Collection collectors) { private InnerHitsContext innerHitsContext; private volatile boolean searchTimedOut; + private volatile boolean terminatedEarly; protected SearchContext() {} @@ -136,6 +137,14 @@ public void setSearchTimedOut(boolean searchTimedOut) { this.searchTimedOut = searchTimedOut; } + public boolean isTerminatedEarly() { + return this.terminatedEarly; + } + + public void setTerminatedEarly(boolean terminatedEarly) { + this.terminatedEarly = terminatedEarly; + } + @Override public final void close() { if (closed.compareAndSet(false, true)) { diff --git a/server/src/main/java/org/opensearch/search/query/ConcurrentQueryPhaseSearcher.java b/server/src/main/java/org/opensearch/search/query/ConcurrentQueryPhaseSearcher.java index e22f766d3894c..15657ce3282af 100644 --- a/server/src/main/java/org/opensearch/search/query/ConcurrentQueryPhaseSearcher.java +++ b/server/src/main/java/org/opensearch/search/query/ConcurrentQueryPhaseSearcher.java @@ -94,6 +94,9 @@ private static boolean searchWithCollectorManager( if (searchContext.terminateAfter() != SearchContext.DEFAULT_TERMINATE_AFTER && queryResult.terminatedEarly() == null) { queryResult.terminatedEarly(false); } + if (searchContext.isTerminatedEarly()) { + queryResult.terminatedEarly(true); + } return topDocsFactory.shouldRescore(); } diff --git a/server/src/main/java/org/opensearch/search/query/EarlyTerminatingCollector.java b/server/src/main/java/org/opensearch/search/query/EarlyTerminatingCollector.java index 5b86a70d64fff..a2c5b34395a3a 100644 --- a/server/src/main/java/org/opensearch/search/query/EarlyTerminatingCollector.java +++ b/server/src/main/java/org/opensearch/search/query/EarlyTerminatingCollector.java @@ -40,6 +40,7 @@ import org.apache.lucene.search.LeafCollector; import java.io.IOException; +import java.util.concurrent.atomic.AtomicLong; /** * A {@link Collector} that early terminates collection after maxCountHits docs have been collected. @@ -47,15 +48,15 @@ * @opensearch.internal */ public class EarlyTerminatingCollector extends FilterCollector { - static final class EarlyTerminationException extends RuntimeException { + public static final class EarlyTerminationException extends RuntimeException { EarlyTerminationException(String msg) { super(msg); } } private final int maxCountHits; - private int numCollected; - private boolean forceTermination; + private final AtomicLong numCollected; + private final boolean forceTermination; private boolean earlyTerminated; /** @@ -69,11 +70,19 @@ static final class EarlyTerminationException extends RuntimeException { super(delegate); this.maxCountHits = maxCountHits; this.forceTermination = forceTermination; + this.numCollected = new AtomicLong(); + } + + EarlyTerminatingCollector(final Collector delegate, int maxCountHits, boolean forceTermination, AtomicLong numCollected) { + super(delegate); + this.maxCountHits = maxCountHits; + this.forceTermination = forceTermination; + this.numCollected = numCollected; } @Override public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException { - if (numCollected >= maxCountHits) { + if (numCollected.get() >= maxCountHits) { earlyTerminated = true; if (forceTermination) { throw new EarlyTerminationException("early termination [CountBased]"); @@ -84,7 +93,7 @@ public LeafCollector getLeafCollector(LeafReaderContext context) throws IOExcept return new FilterLeafCollector(super.getLeafCollector(context)) { @Override public void collect(int doc) throws IOException { - if (++numCollected > maxCountHits) { + if (numCollected.incrementAndGet() > maxCountHits) { earlyTerminated = true; if (forceTermination) { throw new EarlyTerminationException("early termination [CountBased]"); diff --git a/server/src/main/java/org/opensearch/search/query/EarlyTerminatingCollectorManager.java b/server/src/main/java/org/opensearch/search/query/EarlyTerminatingCollectorManager.java index e8153fd384b5d..5db7d82206a48 100644 --- a/server/src/main/java/org/opensearch/search/query/EarlyTerminatingCollectorManager.java +++ b/server/src/main/java/org/opensearch/search/query/EarlyTerminatingCollectorManager.java @@ -15,6 +15,7 @@ import java.util.ArrayList; import java.util.Collection; import java.util.List; +import java.util.concurrent.atomic.AtomicLong; /** * Manager for the EarlyTerminatingCollector @@ -29,16 +30,23 @@ public class EarlyTerminatingCollectorManager private final CollectorManager manager; private final int maxCountHits; private boolean forceTermination; + private final AtomicLong numCollected; EarlyTerminatingCollectorManager(CollectorManager manager, int maxCountHits, boolean forceTermination) { this.manager = manager; this.maxCountHits = maxCountHits; this.forceTermination = forceTermination; + this.numCollected = new AtomicLong(); } @Override public EarlyTerminatingCollector newCollector() throws IOException { - return new EarlyTerminatingCollector(manager.newCollector(), maxCountHits, false /* forced termination is not supported */); + return new EarlyTerminatingCollector( + manager.newCollector(), + maxCountHits, + forceTermination /* forced termination is not supported */, + numCollected + ); } @SuppressWarnings("unchecked") diff --git a/server/src/main/java/org/opensearch/search/query/QueryPhase.java b/server/src/main/java/org/opensearch/search/query/QueryPhase.java index f3cf2c13ecdef..914a34fa77afe 100644 --- a/server/src/main/java/org/opensearch/search/query/QueryPhase.java +++ b/server/src/main/java/org/opensearch/search/query/QueryPhase.java @@ -277,6 +277,7 @@ static boolean executeInternal(SearchContext searchContext, QueryPhaseSearcher q } try { + // EarlyTerminationException gets swallowed here? boolean shouldRescore = queryPhaseSearcher.searchWith( searchContext, searcher, @@ -350,9 +351,8 @@ private static boolean searchWithCollector( queryCollector = QueryCollectorContext.createQueryCollector(collectors); } QuerySearchResult queryResult = searchContext.queryResult(); - try { - searcher.search(query, queryCollector); - } catch (EarlyTerminatingCollector.EarlyTerminationException e) { + searcher.search(query, queryCollector); + if (searchContext.isTerminatedEarly()) { queryResult.terminatedEarly(true); } if (searchContext.isSearchTimedOut()) {