diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java index 81d1b552b..b63693f6a 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java +++ b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java @@ -4,6 +4,7 @@ */ package org.opensearch.neuralsearch.query; +import com.google.common.annotations.VisibleForTesting; import lombok.Getter; import lombok.extern.log4j.Log4j2; import org.apache.lucene.search.DisiPriorityQueue; @@ -30,7 +31,7 @@ * corresponds to order of sub-queries in an input Hybrid query. */ @Log4j2 -public final class HybridQueryScorer extends Scorer { +public class HybridQueryScorer extends Scorer { // score for each of sub-query in this hybrid query @Getter @@ -100,7 +101,8 @@ public float score() throws IOException { return score(getSubMatches()); } - private float score(DisiWrapper topList) throws IOException { + @VisibleForTesting + float score(DisiWrapper topList) throws IOException { float totalScore = 0.0f; for (DisiWrapper disiWrapper = topList; disiWrapper != null; disiWrapper = disiWrapper.next) { // check if this doc has match in the subQuery. If not, add score as 0.0 and continue diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java index 2a6fa49a3..a26dd8263 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java @@ -50,6 +50,7 @@ import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.mapper.MappedFieldType; import org.opensearch.index.mapper.TextFieldMapper; import org.opensearch.index.query.MatchAllQueryBuilder; import org.opensearch.index.query.QueryBuilder; @@ -756,6 +757,69 @@ public void testBoost_whenDefaultBoostSet_thenBuildSuccessfully() { assertNotNull(hybridQueryBuilder); } + @SneakyThrows + public void testBuild_whenValidParameters_thenCreateQuery() { + String queryText = "test query"; + String modelId = "test_model"; + String fieldName = "rank_features"; + + // Create mock context + QueryShardContext context = mock(QueryShardContext.class); + MappedFieldType fieldType = mock(MappedFieldType.class); + when(context.fieldMapper(fieldName)).thenReturn(fieldType); + when(fieldType.typeName()).thenReturn("rank_features"); + + // Create HybridQueryBuilder instance (no spy since it's final) + NeuralSparseQueryBuilder neuralSparseQueryBuilder = new NeuralSparseQueryBuilder(); + neuralSparseQueryBuilder.fieldName(fieldName) + .queryText(queryText) + .modelId(modelId) + .queryTokensSupplier(() -> Map.of("token1", 1.0f, "token2", 0.5f)); + HybridQueryBuilder builder = new HybridQueryBuilder().add(neuralSparseQueryBuilder); + + // Build query + Query query = builder.toQuery(context); + + // Verify + assertNotNull("Query should not be null", query); + assertTrue("Should be HybridQuery", query instanceof HybridQuery); + } + + @SneakyThrows + public void testDoEquals_whenSameParameters_thenEqual() { + // Create neural queries + NeuralQueryBuilder neuralQueryBuilder1 = new NeuralQueryBuilder().queryText("test").modelId("test_model"); + + NeuralQueryBuilder neuralQueryBuilder2 = new NeuralQueryBuilder().queryText("test").modelId("test_model"); + + // Create neural sparse queries with queryTokensSupplier + NeuralSparseQueryBuilder neuralSparseQueryBuilder1 = new NeuralSparseQueryBuilder().fieldName("test_field") + .queryText("test") + .modelId("test_model") + .queryTokensSupplier(() -> Map.of("token1", 1.0f)); + + NeuralSparseQueryBuilder neuralSparseQueryBuilder2 = new NeuralSparseQueryBuilder().fieldName("test_field") + .queryText("test") + .modelId("test_model") + .queryTokensSupplier(() -> Map.of("token1", 1.0f)); + + // Create builders + HybridQueryBuilder builder1 = new HybridQueryBuilder().add(neuralQueryBuilder1).add(neuralSparseQueryBuilder1); + + HybridQueryBuilder builder2 = new HybridQueryBuilder().add(neuralQueryBuilder2).add(neuralSparseQueryBuilder2); + + // Verify + assertTrue("Builders should be equal", builder1.equals(builder2)); + assertEquals("Hash codes should match", builder1.hashCode(), builder2.hashCode()); + } + + public void testValidate_whenInvalidParameters_thenThrowException() { + // Test null query builder + HybridQueryBuilder builderWithNull = new HybridQueryBuilder(); + IllegalArgumentException nullException = assertThrows(IllegalArgumentException.class, () -> builderWithNull.add(null)); + assertEquals("inner hybrid query clause cannot be null", nullException.getMessage()); + } + public void testVisit() { HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder().add(new NeuralQueryBuilder()).add(new NeuralSparseQueryBuilder()); List visitedQueries = new ArrayList<>(); diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryScorerTests.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryScorerTests.java index e7325055e..d02848894 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryScorerTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryScorerTests.java @@ -7,19 +7,27 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import java.io.IOException; import java.util.Arrays; +import java.util.Collections; import java.util.HashSet; +import java.util.List; import java.util.Set; +import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Collectors; import org.apache.commons.lang3.tuple.ImmutablePair; import org.apache.commons.lang3.tuple.Pair; +import org.apache.lucene.search.DisiWrapper; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.MatchAllDocsQuery; import org.apache.lucene.search.MatchNoDocsQuery; +import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.Scorer; import org.apache.lucene.search.TwoPhaseIterator; import org.apache.lucene.search.Weight; @@ -28,6 +36,7 @@ import com.carrotsearch.randomizedtesting.RandomizedTest; import lombok.SneakyThrows; +import org.opensearch.neuralsearch.search.HybridDisiWrapper; public class HybridQueryScorerTests extends OpenSearchQueryTestCase { @@ -275,6 +284,380 @@ public void testApproximationIterator_whenSubScorerSupportsApproximation_thenSuc } } + @SneakyThrows + public void testScore_whenMultipleSubScorers_thenSumScores() { + // Create mock scorers with iterators + Scorer scorer1 = mock(Scorer.class); + DocIdSetIterator iterator1 = mock(DocIdSetIterator.class); + when(scorer1.iterator()).thenReturn(iterator1); + when(scorer1.docID()).thenReturn(1); + when(scorer1.score()).thenReturn(0.5f); + + Scorer scorer2 = mock(Scorer.class); + DocIdSetIterator iterator2 = mock(DocIdSetIterator.class); + when(scorer2.iterator()).thenReturn(iterator2); + when(scorer2.docID()).thenReturn(1); + when(scorer2.score()).thenReturn(0.3f); + + // Create DisiWrapper list + DisiWrapper wrapper1 = new DisiWrapper(scorer1); + wrapper1.next = new DisiWrapper(scorer2); + + Weight weight = mock(Weight.class); + HybridQueryScorer hybridScorer = new HybridQueryScorer(weight, Arrays.asList(scorer1, scorer2)); + float score = hybridScorer.score(wrapper1); + + assertEquals("Combined score should be sum of individual scores", 0.8f, score, 0.0001f); + } + + @SneakyThrows + public void testScore_whenNoMoreDocs_thenReturnZero() { + // Create mock scorer + Scorer scorer = mock(Scorer.class); + DocIdSetIterator iterator = mock(DocIdSetIterator.class); + when(scorer.iterator()).thenReturn(iterator); + + // Setup iterator behavior + when(iterator.docID()).thenReturn(DocIdSetIterator.NO_MORE_DOCS); + when(iterator.cost()).thenReturn(1L); + + // Create TwoPhaseIterator if needed + TwoPhaseIterator twoPhase = mock(TwoPhaseIterator.class); + DocIdSetIterator approximation = mock(DocIdSetIterator.class); + when(approximation.docID()).thenReturn(DocIdSetIterator.NO_MORE_DOCS); + when(approximation.cost()).thenReturn(1L); + when(twoPhase.approximation()).thenReturn(approximation); + when(scorer.twoPhaseIterator()).thenReturn(twoPhase); + + // Create wrapper + DisiWrapper wrapper = new DisiWrapper(scorer); + + // Create weight mock + Weight weight = mock(Weight.class); + + // Create HybridQueryScorer + HybridQueryScorer hybridScorer = new HybridQueryScorer(weight, Collections.singletonList(scorer)); + + // Test score method + float score = hybridScorer.score(wrapper); + + // Verify + assertEquals("Score should be 0.0 for NO_MORE_DOCS", 0.0f, score, 0.0001f); + } + + @SneakyThrows + public void testGetSubMatches_whenNoScorers_thenReturnNull() { + Weight weight = mock(Weight.class); + + // Create a scorer with a two-phase iterator that doesn't match + Scorer scorer = mock(Scorer.class); + DocIdSetIterator iterator = mock(DocIdSetIterator.class); + TwoPhaseIterator twoPhase = mock(TwoPhaseIterator.class); + when(twoPhase.matches()).thenReturn(false); + when(scorer.twoPhaseIterator()).thenReturn(twoPhase); + when(scorer.iterator()).thenReturn(iterator); + when(scorer.docID()).thenReturn(0); // Set a valid docID + + HybridQueryScorer hybridScorer = new HybridQueryScorer(weight, Collections.singletonList(scorer), ScoreMode.TOP_SCORES); + + DisiWrapper result = hybridScorer.getSubMatches(); + assertNull("Should return null when no matches are available", result); + } + + @SneakyThrows + public void testGetSubMatches_whenTwoPhaseIteratorPresent_thenReturnWrapper() { + // Create weight mock + Weight weight = mock(Weight.class); + + // Create scorer with iterator + Scorer scorer = mock(Scorer.class); + DocIdSetIterator iterator = mock(DocIdSetIterator.class); + + // Setup iterator behavior with AtomicInteger for state tracking + AtomicInteger currentDoc = new AtomicInteger(-1); + + when(iterator.docID()).thenAnswer(inv -> currentDoc.get()); + when(iterator.cost()).thenReturn(1L); + when(iterator.nextDoc()).thenAnswer(inv -> { + if (currentDoc.get() == -1) { + currentDoc.set(0); + return 0; + } + return DocIdSetIterator.NO_MORE_DOCS; + }); + + when(scorer.iterator()).thenReturn(iterator); + when(scorer.docID()).thenAnswer(inv -> currentDoc.get()); + + // Create and setup TwoPhaseIterator + TwoPhaseIterator twoPhase = mock(TwoPhaseIterator.class); + DocIdSetIterator approximation = mock(DocIdSetIterator.class); + + // Setup approximation behavior + when(approximation.docID()).thenAnswer(inv -> currentDoc.get()); + when(approximation.cost()).thenReturn(1L); + when(approximation.nextDoc()).thenAnswer(inv -> { + if (currentDoc.get() == -1) { + currentDoc.set(0); + return 0; + } + return DocIdSetIterator.NO_MORE_DOCS; + }); + + when(twoPhase.approximation()).thenReturn(approximation); + when(scorer.twoPhaseIterator()).thenReturn(twoPhase); + when(twoPhase.matches()).thenReturn(true); + + // Create HybridQueryScorer + HybridQueryScorer hybridScorer = new HybridQueryScorer(weight, Collections.singletonList(scorer), ScoreMode.TOP_SCORES); + + // Initialize the scorer by moving to first doc + DocIdSetIterator scorerIterator = hybridScorer.iterator(); + int firstDoc = scorerIterator.nextDoc(); + + // Verify initial state + assertEquals("First doc should be 0", 0, firstDoc); + assertEquals("Iterator should be at doc 0", 0, scorerIterator.docID()); + + // Get submatches + DisiWrapper result = hybridScorer.getSubMatches(); + + // Verify + assertNotNull("Should not be null when twoPhase is present", result); + assertTrue("Should be instance of HybridDisiWrapper", result instanceof HybridDisiWrapper); + assertNotNull("TwoPhaseView should not be null", result.twoPhaseView); + assertEquals("Should be at doc 0", 0, result.doc); + + // Verify the two-phase iterator + TwoPhaseIterator resultTwoPhase = result.twoPhaseView; + assertNotNull("Two-phase iterator should not be null", resultTwoPhase); + assertTrue("Should match", resultTwoPhase.matches()); + } + + @SneakyThrows + public void testAdvanceShallow_whenTargetProvided_thenReturnTarget() { + Weight weight = mock(Weight.class); + + // Create scorer + Scorer scorer = mock(Scorer.class); + DocIdSetIterator iterator = mock(DocIdSetIterator.class); + when(scorer.iterator()).thenReturn(iterator); + + // Create and setup TwoPhaseIterator + TwoPhaseIterator twoPhase = mock(TwoPhaseIterator.class); + DocIdSetIterator approximation = mock(DocIdSetIterator.class); + when(twoPhase.approximation()).thenReturn(approximation); + when(scorer.twoPhaseIterator()).thenReturn(twoPhase); + when(twoPhase.matches()).thenReturn(true); + + // Setup initial state + AtomicInteger currentDoc = new AtomicInteger(-1); + + // Setup iterator behavior + when(iterator.docID()).thenAnswer(inv -> currentDoc.get()); + when(approximation.docID()).thenAnswer(inv -> currentDoc.get()); + + // Setup nextDoc behavior + when(iterator.nextDoc()).thenAnswer(inv -> { + currentDoc.set(0); + return 0; + }); + + when(approximation.nextDoc()).thenAnswer(inv -> { + currentDoc.set(0); + return 0; + }); + + // Setup advance behavior + int target = 5; + when(approximation.advance(target)).thenAnswer(inv -> { + currentDoc.set(target); + return target; + }); + + when(iterator.advance(target)).thenAnswer(inv -> { + currentDoc.set(target); + return target; + }); + + // Setup costs + when(iterator.cost()).thenReturn(1L); + when(approximation.cost()).thenReturn(1L); + + // Create hybrid scorer with custom advanceShallow implementation + HybridQueryScorer hybridScorer = new HybridQueryScorer(weight, Collections.singletonList(scorer), ScoreMode.TOP_SCORES) { + @Override + public float score() throws IOException { + return 1.0f; + } + + @Override + public int advanceShallow(int target) throws IOException { + DisiWrapper lead = getSubMatches(); + if (lead != null && lead.twoPhaseView != null) { + DocIdSetIterator approx = lead.twoPhaseView.approximation(); + int result = approx.advance(target); + return result; + } + return 0; + } + }; + + // Initialize scorer + DocIdSetIterator scorerIterator = hybridScorer.iterator(); + + // Move to first doc + int firstDoc = scorerIterator.nextDoc(); + assertEquals("Should be at first doc", 0, scorerIterator.docID()); + + // Test advanceShallow + int result = hybridScorer.advanceShallow(target); + + // Verify + assertEquals("AdvanceShallow should return the target", target, result); + verify(approximation).advance(target); + assertEquals("Current doc should be at target", target, currentDoc.get()); + } + + @SneakyThrows + public void testScore_whenMultipleQueries_thenCombineScores() { + // Create mock scorers for different queries + Scorer boolScorer = mock(Scorer.class); + DocIdSetIterator boolIterator = mock(DocIdSetIterator.class); + when(boolScorer.iterator()).thenReturn(boolIterator); + when(boolScorer.docID()).thenReturn(1); + when(boolScorer.score()).thenReturn(0.7f); + + Scorer neuralScorer = mock(Scorer.class); + DocIdSetIterator neuralIterator = mock(DocIdSetIterator.class); + when(neuralScorer.iterator()).thenReturn(neuralIterator); + when(neuralScorer.docID()).thenReturn(1); + when(neuralScorer.score()).thenReturn(0.9f); + + // Create DisiWrapper chain + DisiWrapper boolWrapper = new DisiWrapper(boolScorer); + DisiWrapper neuralWrapper = new DisiWrapper(neuralScorer); + boolWrapper.next = neuralWrapper; + + Weight weight = mock(Weight.class); + HybridQueryScorer hybridScorer = new HybridQueryScorer(weight, Arrays.asList(boolScorer, neuralScorer), ScoreMode.COMPLETE); + float combinedScore = hybridScorer.score(boolWrapper); + + assertEquals("Combined score should be sum of bool and neural scores", 1.6f, combinedScore, 0.0001f); + } + + @SneakyThrows + public void testScore_whenEmptySubScorers_thenReturnZero() { + Weight weight = mock(Weight.class); + HybridQueryScorer hybridScorer = new HybridQueryScorer(weight, Collections.emptyList()); + float score = hybridScorer.score(null); + + assertEquals("Score should be 0.0 for null wrapper", 0.0f, score, 0.0001f); + } + + @SneakyThrows + public void testInitialization_whenValidScorer_thenSuccessful() { + // Create scorer with iterator + Scorer scorer = mock(Scorer.class); + DocIdSetIterator iterator = mock(DocIdSetIterator.class); + + // Setup state tracking + AtomicInteger currentDoc = new AtomicInteger(-1); + + // Setup iterator behavior + when(iterator.docID()).thenAnswer(inv -> currentDoc.get()); + when(iterator.cost()).thenReturn(1L); + when(iterator.nextDoc()).thenAnswer(inv -> { + if (currentDoc.get() == -1) { + currentDoc.set(0); + return 0; + } + return DocIdSetIterator.NO_MORE_DOCS; + }); + + when(scorer.iterator()).thenReturn(iterator); + when(scorer.docID()).thenAnswer(inv -> currentDoc.get()); + + // Create wrapper + HybridDisiWrapper wrapper = new HybridDisiWrapper(scorer, 1); + + // Verify + assertNotNull("Wrapper should not be null", wrapper); + assertEquals("Initial doc should be -1", -1, wrapper.doc); + assertNotNull("Iterator should not be null", wrapper.iterator); + assertEquals("Cost should be 1", 1L, wrapper.cost); + } + + @SneakyThrows + public void testHybridScores_withTwoPhaseIterator() throws IOException { + // Create weight and scorers + Weight weight = mock(Weight.class); + Scorer scorer1 = mock(Scorer.class); + TwoPhaseIterator twoPhaseIterator = mock(TwoPhaseIterator.class); + DocIdSetIterator approximation = mock(DocIdSetIterator.class); + + // Setup two-phase behavior + when(scorer1.twoPhaseIterator()).thenReturn(twoPhaseIterator); + when(twoPhaseIterator.approximation()).thenReturn(approximation); + when(scorer1.iterator()).thenReturn(approximation); + when(approximation.cost()).thenReturn(1L); + + // Setup DocIdSetIterator behavior - use different docIDs + when(approximation.docID()).thenReturn(5); // approximation at doc 5 + when(scorer1.docID()).thenReturn(5); // scorer at same doc + when(scorer1.score()).thenReturn(2.0f); + + // matches() always returns false - document should never match + when(twoPhaseIterator.matches()).thenReturn(false); + + // Create HybridQueryScorer with two-phase iterator + List subScorers = Collections.singletonList(scorer1); + HybridQueryScorer hybridScorer = new HybridQueryScorer(weight, subScorers); + + // Call matches() first to establish non-matching state + TwoPhaseIterator hybridTwoPhase = hybridScorer.twoPhaseIterator(); + assertNotNull("Should have two phase iterator", hybridTwoPhase); + assertFalse("Document should not match", hybridTwoPhase.matches()); + + // Get scores - should be zero since document doesn't match + float[] scores = hybridScorer.hybridScores(); + assertEquals("Should have one score entry", 1, scores.length); + assertEquals("Score should be 0 for non-matching document", 0.0f, scores[0], 0.001f); + + // Verify score() was never called since document didn't match + verify(scorer1, never()).score(); + verify(twoPhaseIterator, times(1)).matches(); + } + + @SneakyThrows + public void testTwoPhaseIterator_withNestedTwoPhaseQuery() { + // Create a scorer that uses two-phase iteration + Scorer scorer = mock(Scorer.class); + TwoPhaseIterator twoPhaseIterator = mock(TwoPhaseIterator.class); + DocIdSetIterator approximation = mock(DocIdSetIterator.class); + + // Setup the two-phase behavior + when(scorer.twoPhaseIterator()).thenReturn(twoPhaseIterator); + when(twoPhaseIterator.approximation()).thenReturn(approximation); + when(twoPhaseIterator.matches()).thenReturn(true); + + // Mock iterator() method which is needed for cost calculation + when(scorer.iterator()).thenReturn(approximation); + // Mock cost to avoid NPE + when(approximation.cost()).thenReturn(1L); + + // Create wrapper + HybridDisiWrapper wrapper = new HybridDisiWrapper(scorer, 1); + + // This would return null before PR #998 + TwoPhaseIterator wrapperTwoPhase = wrapper.twoPhaseView; + assertNotNull("Two-phase iterator should not be null", wrapperTwoPhase); + + // Verify that the two-phase behavior is preserved + assertTrue("Should match", wrapperTwoPhase.matches()); + assertSame("Should use same approximation", approximation, wrapperTwoPhase.approximation()); + } + protected static Scorer scorerWithTwoPhaseIterator(final int[] docs, final float[] scores, Weight weight, int maxDoc) { final DocIdSetIterator iterator = DocIdSetIterator.all(maxDoc); return new Scorer(weight) {