Skip to content

Commit

Permalink
Added more unit tests, fixed flaky unit test
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski committed Nov 15, 2023
1 parent dff866d commit 438ae0a
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
package org.opensearch.neuralsearch.query;

import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

import java.io.IOException;
import java.util.Arrays;
Expand All @@ -21,6 +23,7 @@
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.MatchNoDocsQuery;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.Weight;
import org.apache.lucene.tests.util.TestUtil;

Expand Down Expand Up @@ -169,6 +172,63 @@ public void testWithRandomDocuments_whenMultipleScorersAndSomeScorersEmpty_thenR
testWithQuery(docs, scores, hybridQueryScorer);
}

@SneakyThrows
public void testMaxScore_whenMultipleScorers_thenSuccessful() {
int maxDocId = TestUtil.nextInt(random(), 10, 10_000);
Pair<int[], float[]> docsAndScores = generateDocuments(maxDocId);
int[] docs = docsAndScores.getLeft();
float[] scores = docsAndScores.getRight();

Weight weight = mock(Weight.class);

HybridQueryScorer hybridQueryScorerWithAllNonNullSubScorers = new HybridQueryScorer(
weight,
Arrays.asList(
scorer(docs, scores, fakeWeight(new MatchAllDocsQuery())),
scorer(docs, scores, fakeWeight(new MatchNoDocsQuery()))
)
);

float maxScore = hybridQueryScorerWithAllNonNullSubScorers.getMaxScore(Integer.MAX_VALUE);
assertTrue(maxScore > 0.0f);

HybridQueryScorer hybridQueryScorerWithSomeNullSubScorers = new HybridQueryScorer(
weight,
Arrays.asList(null, scorer(docs, scores, fakeWeight(new MatchAllDocsQuery())), null)
);

maxScore = hybridQueryScorerWithSomeNullSubScorers.getMaxScore(Integer.MAX_VALUE);
assertTrue(maxScore > 0.0f);

HybridQueryScorer hybridQueryScorerWithAllNullSubScorers = new HybridQueryScorer(weight, Arrays.asList(null, null));

maxScore = hybridQueryScorerWithAllNullSubScorers.getMaxScore(Integer.MAX_VALUE);
assertEquals(0.0f, maxScore, 0.0f);
}

@SneakyThrows
public void testMaxScoreFailures_whenScorerThrowsException_thenFail() {
int maxDocId = TestUtil.nextInt(random(), 10, 10_000);
Pair<int[], float[]> docsAndScores = generateDocuments(maxDocId);
int[] docs = docsAndScores.getLeft();
float[] scores = docsAndScores.getRight();

Weight weight = mock(Weight.class);

Scorer scorer = mock(Scorer.class);
when(scorer.getWeight()).thenReturn(fakeWeight(new MatchAllDocsQuery()));
when(scorer.iterator()).thenReturn(iterator(docs));
when(scorer.getMaxScore(anyInt())).thenThrow(new IOException("Test exception"));

HybridQueryScorer hybridQueryScorerWithAllNonNullSubScorers = new HybridQueryScorer(weight, Arrays.asList(scorer));

RuntimeException runtimeException = expectThrows(
RuntimeException.class,
() -> hybridQueryScorerWithAllNonNullSubScorers.getMaxScore(Integer.MAX_VALUE)
);
assertTrue(runtimeException.getMessage().contains("Test exception"));
}

private Pair<int[], float[]> generateDocuments(int maxDocId) {
final int numDocs = RandomizedTest.randomIntBetween(1, maxDocId / 2);
final int[] docs = new int[numDocs];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,11 @@ public void testScorerIterator_whenExecuteQuery_thenScorerIteratorSuccessful() {
List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext))
);
IndexSearcher searcher = newSearcher(reader);
Weight weight = searcher.createWeight(hybridQueryWithTerm, ScoreMode.COMPLETE, 1.0f);
Weight weight = hybridQueryWithTerm.createWeight(searcher, ScoreMode.TOP_SCORES, 1.0f);

assertNotNull(weight);

LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0);
LeafReaderContext leafReaderContext = searcher.getIndexReader().leaves().get(0);
Scorer scorer = weight.scorer(leafReaderContext);

assertNotNull(scorer);
Expand Down

0 comments on commit 438ae0a

Please sign in to comment.