diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryScorerTests.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryScorerTests.java index 62ddb64f6..77ca3e64e 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryScorerTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryScorerTests.java @@ -6,7 +6,9 @@ package org.opensearch.neuralsearch.query; import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; +import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; import java.io.IOException; import java.util.Arrays; @@ -21,6 +23,7 @@ import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.MatchAllDocsQuery; import org.apache.lucene.search.MatchNoDocsQuery; +import org.apache.lucene.search.Scorer; import org.apache.lucene.search.Weight; import org.apache.lucene.tests.util.TestUtil; @@ -169,6 +172,63 @@ public void testWithRandomDocuments_whenMultipleScorersAndSomeScorersEmpty_thenR testWithQuery(docs, scores, hybridQueryScorer); } + @SneakyThrows + public void testMaxScore_whenMultipleScorers_thenSuccessful() { + int maxDocId = TestUtil.nextInt(random(), 10, 10_000); + Pair docsAndScores = generateDocuments(maxDocId); + int[] docs = docsAndScores.getLeft(); + float[] scores = docsAndScores.getRight(); + + Weight weight = mock(Weight.class); + + HybridQueryScorer hybridQueryScorerWithAllNonNullSubScorers = new HybridQueryScorer( + weight, + Arrays.asList( + scorer(docs, scores, fakeWeight(new MatchAllDocsQuery())), + scorer(docs, scores, fakeWeight(new MatchNoDocsQuery())) + ) + ); + + float maxScore = hybridQueryScorerWithAllNonNullSubScorers.getMaxScore(Integer.MAX_VALUE); + assertTrue(maxScore > 0.0f); + + HybridQueryScorer hybridQueryScorerWithSomeNullSubScorers = new HybridQueryScorer( + weight, + Arrays.asList(null, scorer(docs, scores, fakeWeight(new MatchAllDocsQuery())), null) + ); + + maxScore = hybridQueryScorerWithSomeNullSubScorers.getMaxScore(Integer.MAX_VALUE); + assertTrue(maxScore > 0.0f); + + HybridQueryScorer hybridQueryScorerWithAllNullSubScorers = new HybridQueryScorer(weight, Arrays.asList(null, null)); + + maxScore = hybridQueryScorerWithAllNullSubScorers.getMaxScore(Integer.MAX_VALUE); + assertEquals(0.0f, maxScore, 0.0f); + } + + @SneakyThrows + public void testMaxScoreFailures_whenScorerThrowsException_thenFail() { + int maxDocId = TestUtil.nextInt(random(), 10, 10_000); + Pair docsAndScores = generateDocuments(maxDocId); + int[] docs = docsAndScores.getLeft(); + float[] scores = docsAndScores.getRight(); + + Weight weight = mock(Weight.class); + + Scorer scorer = mock(Scorer.class); + when(scorer.getWeight()).thenReturn(fakeWeight(new MatchAllDocsQuery())); + when(scorer.iterator()).thenReturn(iterator(docs)); + when(scorer.getMaxScore(anyInt())).thenThrow(new IOException("Test exception")); + + HybridQueryScorer hybridQueryScorerWithAllNonNullSubScorers = new HybridQueryScorer(weight, Arrays.asList(scorer)); + + RuntimeException runtimeException = expectThrows( + RuntimeException.class, + () -> hybridQueryScorerWithAllNonNullSubScorers.getMaxScore(Integer.MAX_VALUE) + ); + assertTrue(runtimeException.getMessage().contains("Test exception")); + } + private Pair generateDocuments(int maxDocId) { final int numDocs = RandomizedTest.randomIntBetween(1, maxDocId / 2); final int[] docs = new int[numDocs]; diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryWeightTests.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryWeightTests.java index c876621a2..8656d7f04 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryWeightTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryWeightTests.java @@ -61,11 +61,11 @@ public void testScorerIterator_whenExecuteQuery_thenScorerIteratorSuccessful() { List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext)) ); IndexSearcher searcher = newSearcher(reader); - Weight weight = searcher.createWeight(hybridQueryWithTerm, ScoreMode.COMPLETE, 1.0f); + Weight weight = hybridQueryWithTerm.createWeight(searcher, ScoreMode.TOP_SCORES, 1.0f); assertNotNull(weight); - LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0); + LeafReaderContext leafReaderContext = searcher.getIndexReader().leaves().get(0); Scorer scorer = weight.scorer(leafReaderContext); assertNotNull(scorer);