diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java index 1d3bc29e9..24ebebe5b 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java @@ -513,81 +513,48 @@ public void testReduceWithConcurrentSegmentSearch_whenMultipleCollectorsMatchedD when(searchContext.query()).thenReturn(hybridQueryWithTerm); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); IndexReader indexReader = mock(IndexReader.class); - when(indexReader.numDocs()).thenReturn(2); + when(indexReader.numDocs()).thenReturn(3); when(indexSearcher.getIndexReader()).thenReturn(indexReader); when(searchContext.searcher()).thenReturn(indexSearcher); - when(searchContext.size()).thenReturn(1); + when(searchContext.size()).thenReturn(10); Map, CollectorManager> classCollectorManagerMap = new HashMap<>(); when(searchContext.queryCollectorManagers()).thenReturn(classCollectorManagerMap); when(searchContext.shouldUseConcurrentSearch()).thenReturn(true); - when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); - Directory directory = newDirectory(); - final IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random()))); + IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random()))); FieldType ft = new FieldType(TextField.TYPE_NOT_STORED); - ft.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS); - ft.setOmitNorms(random().nextBoolean()); + ft.setIndexOptions(IndexOptions.DOCS_AND_FREQS); + ft.setOmitNorms(false); ft.freeze(); - int docId1 = RandomizedTest.randomInt(); - int docId2 = RandomizedTest.randomInt(); - int docId3 = RandomizedTest.randomInt(); - - w.addDocument(getDocument(TEXT_FIELD_NAME, docId1, TEST_DOC_TEXT1, ft)); - w.addDocument(getDocument(TEXT_FIELD_NAME, docId3, TEST_DOC_TEXT3, ft)); + w.addDocument(getDocument(TEXT_FIELD_NAME, 1, TEST_DOC_TEXT1, ft)); + w.addDocument(getDocument(TEXT_FIELD_NAME, 2, TEST_DOC_TEXT2, ft)); + w.addDocument(getDocument(TEXT_FIELD_NAME, 3, TEST_DOC_TEXT3, ft)); w.flush(); w.commit(); - SearchContext searchContext2 = mock(SearchContext.class); - - ContextIndexSearcher indexSearcher2 = mock(ContextIndexSearcher.class); - IndexReader indexReader2 = mock(IndexReader.class); - when(indexReader2.numDocs()).thenReturn(1); - when(indexSearcher2.getIndexReader()).thenReturn(indexReader); - when(searchContext2.searcher()).thenReturn(indexSearcher2); - when(searchContext2.size()).thenReturn(1); - - when(searchContext2.queryCollectorManagers()).thenReturn(new HashMap<>()); - when(searchContext2.shouldUseConcurrentSearch()).thenReturn(true); - - when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); - - Directory directory2 = newDirectory(); - final IndexWriter w2 = new IndexWriter(directory2, newIndexWriterConfig(new MockAnalyzer(random()))); - FieldType ft2 = new FieldType(TextField.TYPE_NOT_STORED); - ft2.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS); - ft2.setOmitNorms(random().nextBoolean()); - ft2.freeze(); - - w2.addDocument(getDocument(TEXT_FIELD_NAME, docId2, TEST_DOC_TEXT2, ft)); - w2.flush(); - w2.commit(); - IndexReader reader = DirectoryReader.open(w); IndexSearcher searcher = newSearcher(reader); - IndexReader reader2 = DirectoryReader.open(w2); - IndexSearcher searcher2 = newSearcher(reader2); CollectorManager hybridCollectorManager = HybridCollectorManager.createHybridCollectorManager(searchContext); HybridTopScoreDocCollector collector1 = (HybridTopScoreDocCollector) hybridCollectorManager.newCollector(); HybridTopScoreDocCollector collector2 = (HybridTopScoreDocCollector) hybridCollectorManager.newCollector(); - Weight weight1 = new HybridQueryWeight(hybridQueryWithTerm, searcher, ScoreMode.TOP_SCORES, BoostingQueryBuilder.DEFAULT_BOOST); - Weight weight2 = new HybridQueryWeight(hybridQueryWithTerm, searcher2, ScoreMode.TOP_SCORES, BoostingQueryBuilder.DEFAULT_BOOST); - collector1.setWeight(weight1); - collector2.setWeight(weight2); + Weight weight = new HybridQueryWeight(hybridQueryWithTerm, searcher, ScoreMode.TOP_SCORES, BoostingQueryBuilder.DEFAULT_BOOST); + collector1.setWeight(weight); + collector2.setWeight(weight); + LeafReaderContext leafReaderContext = searcher.getIndexReader().leaves().get(0); LeafCollector leafCollector1 = collector1.getLeafCollector(leafReaderContext); + LeafCollector leafCollector2 = collector2.getLeafCollector(leafReaderContext); - LeafReaderContext leafReaderContext2 = searcher2.getIndexReader().leaves().get(0); - LeafCollector leafCollector2 = collector2.getLeafCollector(leafReaderContext2); - BulkScorer scorer = weight1.bulkScorer(leafReaderContext); + BulkScorer scorer = weight.bulkScorer(leafReaderContext); scorer.score(leafCollector1, leafReaderContext.reader().getLiveDocs()); + scorer.score(leafCollector2, leafReaderContext.reader().getLiveDocs()); + leafCollector1.finish(); - BulkScorer scorer2 = weight2.bulkScorer(leafReaderContext2); - scorer2.score(leafCollector2, leafReaderContext2.reader().getLiveDocs()); leafCollector2.finish(); Object results = hybridCollectorManager.reduce(List.of(collector1, collector2)); @@ -603,6 +570,7 @@ public void testReduceWithConcurrentSegmentSearch_whenMultipleCollectorsMatchedD assertEquals(TotalHits.Relation.EQUAL_TO, topDocsAndMaxScore.topDocs.totalHits.relation); float maxScore = topDocsAndMaxScore.maxScore; assertTrue(maxScore > 0); + ScoreDoc[] scoreDocs = topDocsAndMaxScore.topDocs.scoreDocs; assertEquals(6, scoreDocs.length); assertEquals(MAGIC_NUMBER_START_STOP, scoreDocs[0].score, DELTA_FOR_ASSERTION); @@ -610,7 +578,6 @@ public void testReduceWithConcurrentSegmentSearch_whenMultipleCollectorsMatchedD assertTrue(scoreDocs[2].score > 0); assertEquals(MAGIC_NUMBER_DELIMITER, scoreDocs[3].score, DELTA_FOR_ASSERTION); assertTrue(scoreDocs[4].score > 0); - assertEquals(MAGIC_NUMBER_START_STOP, scoreDocs[5].score, DELTA_FOR_ASSERTION); // we have to assert that one of hits is max score because scores are generated for each run and order is not guaranteed assertTrue(Float.compare(scoreDocs[2].score, maxScore) == 0 || Float.compare(scoreDocs[4].score, maxScore) == 0); @@ -618,9 +585,6 @@ public void testReduceWithConcurrentSegmentSearch_whenMultipleCollectorsMatchedD w.close(); reader.close(); directory.close(); - w2.close(); - reader2.close(); - directory2.close(); } @SneakyThrows