diff --git a/CHANGELOG.md b/CHANGELOG.md index 36c6be493..479bf1877 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.14...2.x) ### Features ### Enhancements +- Pass empty doc collector instead of top docs collector to improve hybrid query latencies by 20% ([#731](https://github.com/opensearch-project/neural-search/pull/731)) ### Bug Fixes - Fix multi node "no such index" error in text chunking processor ([#713](https://github.com/opensearch-project/neural-search/pull/713)) ### Infrastructure diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/ConcurrentQueryPhaseSearcherWithEmptyQueryCollectorContext.java b/src/main/java/org/opensearch/neuralsearch/search/query/ConcurrentQueryPhaseSearcherWithEmptyQueryCollectorContext.java new file mode 100644 index 000000000..71843d5f2 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/search/query/ConcurrentQueryPhaseSearcherWithEmptyQueryCollectorContext.java @@ -0,0 +1,41 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.search.query; + +import org.apache.lucene.search.Query; +import org.opensearch.search.internal.ContextIndexSearcher; +import org.opensearch.search.internal.SearchContext; +import org.opensearch.search.query.ConcurrentQueryPhaseSearcher; +import org.opensearch.search.query.QueryCollectorContext; + +import java.io.IOException; +import java.util.LinkedList; + +/** + * Class that inherits ConcurrentQueryPhaseSearcher implementation but calls its search with only + * empty query collector context + */ +public class ConcurrentQueryPhaseSearcherWithEmptyQueryCollectorContext extends ConcurrentQueryPhaseSearcher { + + @Override + protected boolean searchWithCollector( + SearchContext searchContext, + ContextIndexSearcher searcher, + Query query, + LinkedList collectors, + boolean hasFilterCollector, + boolean hasTimeout + ) throws IOException { + return searchWithCollector( + searchContext, + searcher, + query, + collectors, + QueryCollectorContext.EMPTY_CONTEXT, + hasFilterCollector, + hasTimeout + ); + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/DefaultQueryPhaseSearcherWithEmptyQueryCollectorContext.java b/src/main/java/org/opensearch/neuralsearch/search/query/DefaultQueryPhaseSearcherWithEmptyQueryCollectorContext.java new file mode 100644 index 000000000..179f81e7f --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/search/query/DefaultQueryPhaseSearcherWithEmptyQueryCollectorContext.java @@ -0,0 +1,41 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.search.query; + +import org.apache.lucene.search.Query; +import org.opensearch.search.internal.ContextIndexSearcher; +import org.opensearch.search.internal.SearchContext; +import org.opensearch.search.query.QueryCollectorContext; +import org.opensearch.search.query.QueryPhase; + +import java.io.IOException; +import java.util.LinkedList; + +/** + * Class that inherits DefaultQueryPhaseSearcher implementation but calls its search with only + * empty query collector context + */ +public class DefaultQueryPhaseSearcherWithEmptyQueryCollectorContext extends QueryPhase.DefaultQueryPhaseSearcher { + + @Override + protected boolean searchWithCollector( + SearchContext searchContext, + ContextIndexSearcher searcher, + Query query, + LinkedList collectors, + boolean hasFilterCollector, + boolean hasTimeout + ) throws IOException { + return searchWithCollector( + searchContext, + searcher, + query, + collectors, + QueryCollectorContext.EMPTY_CONTEXT, + hasFilterCollector, + hasTimeout + ); + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java index b97134f8f..7b96ebff2 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java +++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java @@ -21,6 +21,7 @@ import org.opensearch.search.internal.SearchContext; import org.opensearch.search.query.QueryCollectorContext; import org.opensearch.search.query.QueryPhase; +import org.opensearch.search.query.QueryPhaseSearcher; import org.opensearch.search.query.QueryPhaseSearcherWrapper; import lombok.extern.log4j.Log4j2; @@ -36,6 +37,14 @@ @Log4j2 public class HybridQueryPhaseSearcher extends QueryPhaseSearcherWrapper { + private final QueryPhaseSearcher defaultQueryPhaseSearcherWithEmptyCollectorContext; + private final QueryPhaseSearcher concurrentQueryPhaseSearcherWithEmptyCollectorContext; + + public HybridQueryPhaseSearcher() { + this.defaultQueryPhaseSearcherWithEmptyCollectorContext = new DefaultQueryPhaseSearcherWithEmptyQueryCollectorContext(); + this.concurrentQueryPhaseSearcherWithEmptyCollectorContext = new ConcurrentQueryPhaseSearcherWithEmptyQueryCollectorContext(); + } + public boolean searchWith( final SearchContext searchContext, final ContextIndexSearcher searcher, @@ -49,10 +58,17 @@ public boolean searchWith( return super.searchWith(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout); } else { Query hybridQuery = extractHybridQuery(searchContext, query); - return super.searchWith(searchContext, searcher, hybridQuery, collectors, hasFilterCollector, hasTimeout); + QueryPhaseSearcher queryPhaseSearcher = getQueryPhaseSearcher(searchContext); + return queryPhaseSearcher.searchWith(searchContext, searcher, hybridQuery, collectors, hasFilterCollector, hasTimeout); } } + private QueryPhaseSearcher getQueryPhaseSearcher(final SearchContext searchContext) { + return searchContext.shouldUseConcurrentSearch() + ? concurrentQueryPhaseSearcherWithEmptyCollectorContext + : defaultQueryPhaseSearcherWithEmptyCollectorContext; + } + private static boolean isWrappedHybridQuery(final Query query) { return query instanceof BooleanQuery && ((BooleanQuery) query).clauses().stream().anyMatch(clauseQuery -> clauseQuery.getQuery() instanceof HybridQuery); diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/ConcurrentQueryPhaseSearcherWithEmptyQueryCollectorContextTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/ConcurrentQueryPhaseSearcherWithEmptyQueryCollectorContextTests.java new file mode 100644 index 000000000..5ad641be2 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/search/query/ConcurrentQueryPhaseSearcherWithEmptyQueryCollectorContextTests.java @@ -0,0 +1,128 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.search.query; + +import com.carrotsearch.randomizedtesting.RandomizedTest; +import lombok.SneakyThrows; +import org.apache.lucene.document.FieldType; +import org.apache.lucene.document.TextField; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexOptions; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Query; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.analysis.MockAnalyzer; +import org.opensearch.action.OriginalIndices; +import org.opensearch.core.index.Index; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.index.mapper.MapperService; +import org.opensearch.index.mapper.TextFieldMapper; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.QueryShardContext; +import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.index.shard.IndexShard; +import org.opensearch.neuralsearch.query.HybridQueryBuilder; +import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; +import org.opensearch.search.SearchShardTarget; +import org.opensearch.search.internal.ContextIndexSearcher; +import org.opensearch.search.internal.SearchContext; +import org.opensearch.search.query.QueryCollectorContext; +import org.opensearch.search.query.QuerySearchResult; + +import java.io.IOException; +import java.util.LinkedList; + +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.when; + +public class ConcurrentQueryPhaseSearcherWithEmptyQueryCollectorContextTests extends OpenSearchQueryTestCase { + private static final String TEXT_FIELD_NAME = "field"; + private static final String TEST_DOC_TEXT1 = "Hello world"; + private static final String QUERY_TEXT1 = "hello"; + private static final Index dummyIndex = new Index("dummy", "dummy"); + + @SneakyThrows + public void testQueryResult_whenOneSubQueryWithHits_thenHybridResultsAreSet() { + ConcurrentQueryPhaseSearcherWithEmptyQueryCollectorContext queryPhaseSearcher = spy( + new ConcurrentQueryPhaseSearcherWithEmptyQueryCollectorContext() + ); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + when(mockQueryShardContext.index()).thenReturn(dummyIndex); + MapperService mapperService = createMapperService(); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) mapperService.fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + + Directory directory = newDirectory(); + IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random()))); + FieldType ft = new FieldType(TextField.TYPE_NOT_STORED); + ft.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS); + ft.setOmitNorms(random().nextBoolean()); + ft.freeze(); + int docId1 = RandomizedTest.randomInt(); + w.addDocument(getDocument(TEXT_FIELD_NAME, docId1, TEST_DOC_TEXT1, ft)); + w.commit(); + + IndexReader reader = DirectoryReader.open(w); + SearchContext searchContext = mock(SearchContext.class); + + ContextIndexSearcher contextIndexSearcher = new ContextIndexSearcher( + reader, + IndexSearcher.getDefaultSimilarity(), + IndexSearcher.getDefaultQueryCache(), + IndexSearcher.getDefaultQueryCachingPolicy(), + true, + null, + searchContext + ); + + ShardId shardId = new ShardId(dummyIndex, 1); + SearchShardTarget shardTarget = new SearchShardTarget( + randomAlphaOfLength(10), + shardId, + randomAlphaOfLength(10), + OriginalIndices.NONE + ); + when(searchContext.shardTarget()).thenReturn(shardTarget); + when(searchContext.searcher()).thenReturn(contextIndexSearcher); + when(searchContext.size()).thenReturn(3); + when(searchContext.numberOfShards()).thenReturn(1); + when(searchContext.searcher()).thenReturn(contextIndexSearcher); + IndexShard indexShard = mock(IndexShard.class); + when(indexShard.shardId()).thenReturn(new ShardId("test", "test", 0)); + when(searchContext.indexShard()).thenReturn(indexShard); + QuerySearchResult querySearchResult = new QuerySearchResult(); + when(searchContext.queryResult()).thenReturn(querySearchResult); + when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR); + when(searchContext.mapperService()).thenReturn(mapperService); + + LinkedList collectors = new LinkedList<>(); + boolean hasFilterCollector = randomBoolean(); + boolean hasTimeout = randomBoolean(); + + HybridQueryBuilder queryBuilder = new HybridQueryBuilder(); + + TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1); + queryBuilder.add(termSubQuery); + + Query query = queryBuilder.toQuery(mockQueryShardContext); + when(searchContext.query()).thenReturn(query); + queryPhaseSearcher.aggregationProcessor(searchContext).preProcess(searchContext); + queryPhaseSearcher.searchWith(searchContext, contextIndexSearcher, query, collectors, hasFilterCollector, hasTimeout); + + assertTrue(querySearchResult.hasConsumedTopDocs()); + + releaseResources(directory, w, reader); + } + + private void releaseResources(Directory directory, IndexWriter w, IndexReader reader) throws IOException { + w.close(); + reader.close(); + directory.close(); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/DefaultQueryPhaseSearcherWithEmptyQueryCollectorContextTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/DefaultQueryPhaseSearcherWithEmptyQueryCollectorContextTests.java new file mode 100644 index 000000000..51572000c --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/search/query/DefaultQueryPhaseSearcherWithEmptyQueryCollectorContextTests.java @@ -0,0 +1,128 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.search.query; + +import com.carrotsearch.randomizedtesting.RandomizedTest; +import lombok.SneakyThrows; +import org.apache.lucene.document.FieldType; +import org.apache.lucene.document.TextField; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexOptions; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Query; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.analysis.MockAnalyzer; +import org.opensearch.action.OriginalIndices; +import org.opensearch.core.index.Index; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.index.mapper.MapperService; +import org.opensearch.index.mapper.TextFieldMapper; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.QueryShardContext; +import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.index.shard.IndexShard; +import org.opensearch.neuralsearch.query.HybridQueryBuilder; +import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; +import org.opensearch.search.SearchShardTarget; +import org.opensearch.search.internal.ContextIndexSearcher; +import org.opensearch.search.internal.SearchContext; +import org.opensearch.search.query.QueryCollectorContext; +import org.opensearch.search.query.QuerySearchResult; + +import java.io.IOException; +import java.util.LinkedList; + +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.when; + +public class DefaultQueryPhaseSearcherWithEmptyQueryCollectorContextTests extends OpenSearchQueryTestCase { + private static final String TEXT_FIELD_NAME = "field"; + private static final String TEST_DOC_TEXT1 = "Hello world"; + private static final String QUERY_TEXT1 = "hello"; + private static final Index dummyIndex = new Index("dummy", "dummy"); + + @SneakyThrows + public void testQueryResult_whenOneSubQueryWithHits_thenHybridResultsAreSet() { + DefaultQueryPhaseSearcherWithEmptyQueryCollectorContext queryPhaseSearcher = spy( + new DefaultQueryPhaseSearcherWithEmptyQueryCollectorContext() + ); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + when(mockQueryShardContext.index()).thenReturn(dummyIndex); + MapperService mapperService = createMapperService(); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) mapperService.fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + + Directory directory = newDirectory(); + IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random()))); + FieldType ft = new FieldType(TextField.TYPE_NOT_STORED); + ft.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS); + ft.setOmitNorms(random().nextBoolean()); + ft.freeze(); + int docId1 = RandomizedTest.randomInt(); + w.addDocument(getDocument(TEXT_FIELD_NAME, docId1, TEST_DOC_TEXT1, ft)); + w.commit(); + + IndexReader reader = DirectoryReader.open(w); + SearchContext searchContext = mock(SearchContext.class); + + ContextIndexSearcher contextIndexSearcher = new ContextIndexSearcher( + reader, + IndexSearcher.getDefaultSimilarity(), + IndexSearcher.getDefaultQueryCache(), + IndexSearcher.getDefaultQueryCachingPolicy(), + true, + null, + searchContext + ); + + ShardId shardId = new ShardId(dummyIndex, 1); + SearchShardTarget shardTarget = new SearchShardTarget( + randomAlphaOfLength(10), + shardId, + randomAlphaOfLength(10), + OriginalIndices.NONE + ); + when(searchContext.shardTarget()).thenReturn(shardTarget); + when(searchContext.searcher()).thenReturn(contextIndexSearcher); + when(searchContext.size()).thenReturn(3); + when(searchContext.numberOfShards()).thenReturn(1); + when(searchContext.searcher()).thenReturn(contextIndexSearcher); + IndexShard indexShard = mock(IndexShard.class); + when(indexShard.shardId()).thenReturn(new ShardId("test", "test", 0)); + when(searchContext.indexShard()).thenReturn(indexShard); + QuerySearchResult querySearchResult = new QuerySearchResult(); + when(searchContext.queryResult()).thenReturn(querySearchResult); + when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR); + when(searchContext.mapperService()).thenReturn(mapperService); + + LinkedList collectors = new LinkedList<>(); + boolean hasFilterCollector = randomBoolean(); + boolean hasTimeout = randomBoolean(); + + HybridQueryBuilder queryBuilder = new HybridQueryBuilder(); + + TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1); + queryBuilder.add(termSubQuery); + + Query query = queryBuilder.toQuery(mockQueryShardContext); + when(searchContext.query()).thenReturn(query); + queryPhaseSearcher.aggregationProcessor(searchContext).preProcess(searchContext); + queryPhaseSearcher.searchWith(searchContext, contextIndexSearcher, query, collectors, hasFilterCollector, hasTimeout); + + assertTrue(querySearchResult.hasConsumedTopDocs()); + + releaseResources(directory, w, reader); + } + + private void releaseResources(Directory directory, IndexWriter w, IndexReader reader) throws IOException { + w.close(); + reader.close(); + directory.close(); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java index a938b2111..b606aac3e 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java @@ -20,13 +20,11 @@ import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.isHybridQueryStartStopElement; import java.io.IOException; -import java.util.Arrays; +import java.util.HashMap; import java.util.LinkedList; -import java.util.List; import java.util.Map; import java.util.Set; -import java.util.UUID; -import java.util.stream.Collectors; +import java.util.concurrent.ExecutorService; import org.apache.lucene.document.Document; import org.apache.lucene.document.Field; @@ -39,6 +37,8 @@ import org.apache.lucene.index.IndexWriter; import org.apache.lucene.search.BooleanClause; import org.apache.lucene.search.BooleanQuery; +import org.apache.lucene.search.Collector; +import org.apache.lucene.search.CollectorManager; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.Query; import org.apache.lucene.search.ScoreDoc; @@ -58,8 +58,6 @@ import org.opensearch.index.mapper.MapperService; import org.opensearch.index.mapper.TextFieldMapper; import org.opensearch.index.query.BoolQueryBuilder; -import org.opensearch.index.query.MatchAllQueryBuilder; -import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryBuilders; import org.opensearch.index.query.QueryShardContext; import org.opensearch.index.query.TermQueryBuilder; @@ -78,6 +76,7 @@ import lombok.SneakyThrows; import org.opensearch.search.query.QueryCollectorContext; import org.opensearch.search.query.QuerySearchResult; +import org.opensearch.search.query.ReduceableSearchResult; public class HybridQueryPhaseSearcherTests extends OpenSearchQueryTestCase { private static final String VECTOR_FIELD_NAME = "vectorField"; @@ -88,13 +87,7 @@ public class HybridQueryPhaseSearcherTests extends OpenSearchQueryTestCase { private static final String TEST_DOC_TEXT4 = "This is really nice place to be"; private static final String QUERY_TEXT1 = "hello"; private static final String QUERY_TEXT2 = "randomkeyword"; - private static final String QUERY_TEXT3 = "place"; private static final Index dummyIndex = new Index("dummy", "dummy"); - private static final String MODEL_ID = "mfgfgdsfgfdgsde"; - private static final int K = 10; - private static final QueryBuilder TEST_FILTER = new MatchAllQueryBuilder(); - private static final UUID INDEX_UUID = UUID.randomUUID(); - private static final String TEST_INDEX = "index"; @SneakyThrows public void testQueryType_whenQueryIsHybrid_thenCallHybridDocCollector() { @@ -306,20 +299,22 @@ public void testQueryResult_whenOneSubQueryWithHits_thenHybridResultsAreSet() { Query query = queryBuilder.toQuery(mockQueryShardContext); when(searchContext.query()).thenReturn(query); + CollectorManager collectorManager = HybridCollectorManager + .createHybridCollectorManager(searchContext); + Map, CollectorManager> queryCollectorManagers = new HashMap<>(); + queryCollectorManagers.put(HybridCollectorManager.class, collectorManager); + when(searchContext.queryCollectorManagers()).thenReturn(queryCollectorManagers); + hybridQueryPhaseSearcher.searchWith(searchContext, contextIndexSearcher, query, collectors, hasFilterCollector, hasTimeout); + hybridQueryPhaseSearcher.aggregationProcessor(searchContext).postProcess(searchContext); assertNotNull(querySearchResult.topDocs()); TopDocsAndMaxScore topDocsAndMaxScore = querySearchResult.topDocs(); TopDocs topDocs = topDocsAndMaxScore.topDocs; - assertEquals(1, topDocs.totalHits.value); + assertEquals(0, topDocs.totalHits.value); ScoreDoc[] scoreDocs = topDocs.scoreDocs; assertNotNull(scoreDocs); - assertEquals(1, scoreDocs.length); - ScoreDoc scoreDoc = scoreDocs[0]; - assertNotNull(scoreDoc); - int actualDocId = Integer.parseInt(reader.document(scoreDoc.doc).getField("id").stringValue()); - assertEquals(docId1, actualDocId); - assertTrue(scoreDoc.score > 0.0f); + assertEquals(0, scoreDocs.length); releaseResources(directory, w, reader); } @@ -340,13 +335,7 @@ public void testQueryResult_whenMultipleTextSubQueriesWithSomeHits_thenHybridRes ft.setOmitNorms(random().nextBoolean()); ft.freeze(); int docId1 = RandomizedTest.randomInt(); - int docId2 = RandomizedTest.randomInt(); - int docId3 = RandomizedTest.randomInt(); - int docId4 = RandomizedTest.randomInt(); w.addDocument(getDocument(TEXT_FIELD_NAME, docId1, TEST_DOC_TEXT1, ft)); - w.addDocument(getDocument(TEXT_FIELD_NAME, docId2, TEST_DOC_TEXT2, ft)); - w.addDocument(getDocument(TEXT_FIELD_NAME, docId3, TEST_DOC_TEXT3, ft)); - w.addDocument(getDocument(TEXT_FIELD_NAME, docId4, TEST_DOC_TEXT4, ft)); w.commit(); IndexReader reader = DirectoryReader.open(w); @@ -395,18 +384,22 @@ public void testQueryResult_whenMultipleTextSubQueriesWithSomeHits_thenHybridRes Query query = queryBuilder.toQuery(mockQueryShardContext); when(searchContext.query()).thenReturn(query); + CollectorManager collectorManager = HybridCollectorManager + .createHybridCollectorManager(searchContext); + Map, CollectorManager> queryCollectorManagers = new HashMap<>(); + queryCollectorManagers.put(HybridCollectorManager.class, collectorManager); + when(searchContext.queryCollectorManagers()).thenReturn(queryCollectorManagers); + hybridQueryPhaseSearcher.searchWith(searchContext, contextIndexSearcher, query, collectors, hasFilterCollector, hasTimeout); + hybridQueryPhaseSearcher.aggregationProcessor(searchContext).postProcess(searchContext); assertNotNull(querySearchResult.topDocs()); TopDocsAndMaxScore topDocsAndMaxScore = querySearchResult.topDocs(); TopDocs topDocs = topDocsAndMaxScore.topDocs; - assertEquals(4, topDocs.totalHits.value); + assertEquals(0, topDocs.totalHits.value); ScoreDoc[] scoreDocs = topDocs.scoreDocs; assertNotNull(scoreDocs); - assertEquals(4, scoreDocs.length); - List expectedIds = List.of(0, 1, 2, 3); - List actualDocIds = Arrays.stream(scoreDocs).map(sd -> sd.doc).collect(Collectors.toList()); - assertEquals(expectedIds, actualDocIds); + assertEquals(0, scoreDocs.length); releaseResources(directory, w, reader); } @@ -705,18 +698,22 @@ public void testWrappedHybridQuery_whenHybridWrappedIntoBoolBecauseOfNested_then when(searchContext.query()).thenReturn(query); + CollectorManager collectorManager = HybridCollectorManager + .createHybridCollectorManager(searchContext); + Map, CollectorManager> queryCollectorManagers = new HashMap<>(); + queryCollectorManagers.put(HybridCollectorManager.class, collectorManager); + when(searchContext.queryCollectorManagers()).thenReturn(queryCollectorManagers); + hybridQueryPhaseSearcher.searchWith(searchContext, contextIndexSearcher, query, collectors, hasFilterCollector, hasTimeout); + hybridQueryPhaseSearcher.aggregationProcessor(searchContext).postProcess(searchContext); assertNotNull(querySearchResult.topDocs()); TopDocsAndMaxScore topDocsAndMaxScore = querySearchResult.topDocs(); TopDocs topDocs = topDocsAndMaxScore.topDocs; - assertTrue(topDocs.totalHits.value > 0); + assertEquals(0, topDocs.totalHits.value); ScoreDoc[] scoreDocs = topDocs.scoreDocs; assertNotNull(scoreDocs); - assertEquals(1, scoreDocs.length); - ScoreDoc scoreDoc = scoreDocs[0]; - assertTrue(scoreDoc.score > 0); - assertEquals(0, scoreDoc.doc); + assertEquals(0, scoreDocs.length); releaseResources(directory, w, reader); } @@ -979,18 +976,22 @@ public void testAliasWithFilter_whenHybridWrappedIntoBoolBecauseOfIndexAlias_the when(searchContext.query()).thenReturn(query); when(searchContext.aliasFilter()).thenReturn(termFilter); + CollectorManager collectorManager = HybridCollectorManager + .createHybridCollectorManager(searchContext); + Map, CollectorManager> queryCollectorManagers = new HashMap<>(); + queryCollectorManagers.put(HybridCollectorManager.class, collectorManager); + when(searchContext.queryCollectorManagers()).thenReturn(queryCollectorManagers); + hybridQueryPhaseSearcher.searchWith(searchContext, contextIndexSearcher, query, collectors, hasFilterCollector, hasTimeout); + hybridQueryPhaseSearcher.aggregationProcessor(searchContext).postProcess(searchContext); assertNotNull(querySearchResult.topDocs()); TopDocsAndMaxScore topDocsAndMaxScore = querySearchResult.topDocs(); TopDocs topDocs = topDocsAndMaxScore.topDocs; - assertTrue(topDocs.totalHits.value > 0); + assertEquals(0, topDocs.totalHits.value); ScoreDoc[] scoreDocs = topDocs.scoreDocs; assertNotNull(scoreDocs); - assertEquals(1, scoreDocs.length); - ScoreDoc scoreDoc = scoreDocs[0]; - assertTrue(scoreDoc.score > 0); - assertEquals(0, scoreDoc.doc); + assertEquals(0, scoreDocs.length); releaseResources(directory, w, reader); } @@ -1038,4 +1039,26 @@ private static IndexMetadata getIndexMetadata() { .build(); return indexMetadata; } + + private static ContextIndexSearcher newContextSearcher(IndexReader reader, ExecutorService executor) throws IOException { + SearchContext searchContext = mock(SearchContext.class); + IndexShard indexShard = mock(IndexShard.class); + when(searchContext.indexShard()).thenReturn(indexShard); + when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR); + when(searchContext.shouldUseConcurrentSearch()).thenReturn(executor != null); + if (executor != null) { + when(searchContext.getTargetMaxSliceCount()).thenReturn(randomIntBetween(0, 2)); + } else { + when(searchContext.getTargetMaxSliceCount()).thenThrow(IllegalStateException.class); + } + return new ContextIndexSearcher( + reader, + IndexSearcher.getDefaultSimilarity(), + IndexSearcher.getDefaultQueryCache(), + IndexSearcher.getDefaultQueryCachingPolicy(), + true, + executor, + searchContext + ); + } }