From 585fbbe0da23bdb827bf012a85d6b2dc9af78c11 Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Fri, 29 Dec 2023 14:20:00 -0800 Subject: [PATCH] Fixing multiple issues reported in #497 (#524) * Allow multiple identical sub-queries in hybrid query, removed validation for total hits Signed-off-by: Martin Gaievski --- CHANGELOG.md | 1 + .../NormalizationProcessorWorkflow.java | 16 +- .../neuralsearch/query/HybridQueryScorer.java | 39 +++- .../search/HitsThresholdChecker.java | 3 - .../NormalizationProcessorTests.java | 174 ++++++++++++++++++ .../neuralsearch/query/HybridQueryIT.java | 69 +++++++ .../query/HybridQueryWeightTests.java | 61 +++++- ...ts.java => HitsThresholdCheckerTests.java} | 16 +- .../HybridTopScoreDocCollectorTests.java | 51 +++++ 9 files changed, 411 insertions(+), 19 deletions(-) rename src/test/java/org/opensearch/neuralsearch/search/{HitsTresholdCheckerTests.java => HitsThresholdCheckerTests.java} (53%) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3f02272c4..c9896ff13 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Features ### Enhancements ### Bug Fixes +- Fixing multiple issues reported in #497 ([#524](https://github.com/opensearch-project/neural-search/pull/524)) ### Infrastructure ### Documentation ### Maintenance diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java index b8bc86de5..c322102d5 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java @@ -173,8 +173,20 @@ private SearchHit[] getSearchHits(final List docIds, final FetchSearchR SearchHits searchHits = fetchSearchResult.hits(); SearchHit[] searchHitArray = searchHits.getHits(); // validate the both collections are of the same size - if (Objects.isNull(searchHitArray) || searchHitArray.length != docIds.size()) { - throw new IllegalStateException("Score normalization processor cannot produce final query result"); + if (Objects.isNull(searchHitArray)) { + throw new IllegalStateException( + "score normalization processor cannot produce final query result, fetch query phase returns empty results" + ); + } + if (searchHitArray.length != docIds.size()) { + throw new IllegalStateException( + String.format( + Locale.ROOT, + "score normalization processor cannot produce final query result, the number of documents after fetch phase [%d] is different from number of documents from query phase [%d]", + searchHitArray.length, + docIds.size() + ) + ); } return searchHitArray; } diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java index 109a50d05..e3e6a0862 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java +++ b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java @@ -6,9 +6,11 @@ package org.opensearch.neuralsearch.query; import java.io.IOException; +import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.Objects; @@ -37,7 +39,7 @@ public final class HybridQueryScorer extends Scorer { private final float[] subScores; - private final Map queryToIndex; + private final Map> queryToIndex; public HybridQueryScorer(Weight weight, List subScorers) throws IOException { super(weight); @@ -111,24 +113,43 @@ public float[] hybridScores() throws IOException { DisiWrapper topList = subScorersPQ.topList(); for (DisiWrapper disiWrapper = topList; disiWrapper != null; disiWrapper = disiWrapper.next) { // check if this doc has match in the subQuery. If not, add score as 0.0 and continue - if (disiWrapper.scorer.docID() == DocIdSetIterator.NO_MORE_DOCS) { + Scorer scorer = disiWrapper.scorer; + if (scorer.docID() == DocIdSetIterator.NO_MORE_DOCS) { continue; } - float subScore = disiWrapper.scorer.score(); - scores[queryToIndex.get(disiWrapper.scorer.getWeight().getQuery())] = subScore; + Query query = scorer.getWeight().getQuery(); + List indexes = queryToIndex.get(query); + // we need to find the index of first sub-query that hasn't been set yet. Such score will have initial value of "0.0" + int index = indexes.stream() + .mapToInt(idx -> idx) + .filter(idx -> Float.compare(scores[idx], 0.0f) == 0) + .findFirst() + .orElseThrow( + () -> new IllegalStateException( + String.format( + Locale.ROOT, + "cannot set score for one of hybrid search subquery [%s] and document [%d]", + query.toString(), + scorer.docID() + ) + ) + ); + scores[index] = scorer.score(); } return scores; } - private Map mapQueryToIndex() { - Map queryToIndex = new HashMap<>(); + private Map> mapQueryToIndex() { + Map> queryToIndex = new HashMap<>(); int idx = 0; for (Scorer scorer : subScorers) { if (scorer == null) { idx++; continue; } - queryToIndex.put(scorer.getWeight().getQuery(), idx); + Query query = scorer.getWeight().getQuery(); + queryToIndex.putIfAbsent(query, new ArrayList<>()); + queryToIndex.get(query).add(idx); idx++; } return queryToIndex; @@ -137,7 +158,9 @@ private Map mapQueryToIndex() { private DisiPriorityQueue initializeSubScorersPQ() { Objects.requireNonNull(queryToIndex, "should not be null"); Objects.requireNonNull(subScorers, "should not be null"); - DisiPriorityQueue subScorersPQ = new DisiPriorityQueue(queryToIndex.size()); + // we need to count this way in order to include all identical sub-queries + int numOfSubQueries = queryToIndex.values().stream().map(List::size).reduce(0, Integer::sum); + DisiPriorityQueue subScorersPQ = new DisiPriorityQueue(numOfSubQueries); for (Scorer scorer : subScorers) { if (scorer == null) { continue; diff --git a/src/main/java/org/opensearch/neuralsearch/search/HitsThresholdChecker.java b/src/main/java/org/opensearch/neuralsearch/search/HitsThresholdChecker.java index c8c52320f..dea9c6bae 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/HitsThresholdChecker.java +++ b/src/main/java/org/opensearch/neuralsearch/search/HitsThresholdChecker.java @@ -24,9 +24,6 @@ public HitsThresholdChecker(int totalHitsThreshold) { if (totalHitsThreshold < 0) { throw new IllegalArgumentException(String.format(Locale.ROOT, "totalHitsThreshold must be >= 0, got %d", totalHitsThreshold)); } - if (totalHitsThreshold == Integer.MAX_VALUE) { - throw new IllegalArgumentException(String.format(Locale.ROOT, "totalHitsThreshold must be less than max integer value")); - } this.totalHitsThreshold = totalHitsThreshold; } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java index 41348ec49..26d9fc808 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java @@ -5,6 +5,7 @@ package org.opensearch.neuralsearch.processor; +import static org.hamcrest.Matchers.startsWith; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; @@ -15,6 +16,7 @@ import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createStartStopElementForHybridSearchResults; import java.util.List; +import java.util.Map; import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicReference; import java.util.stream.Collectors; @@ -45,9 +47,13 @@ import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationFactory; import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer; import org.opensearch.search.DocValueFormat; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; import org.opensearch.search.SearchShardTarget; import org.opensearch.search.aggregations.InternalAggregation; import org.opensearch.search.aggregations.pipeline.PipelineAggregator; +import org.opensearch.search.fetch.FetchSearchResult; +import org.opensearch.search.fetch.QueryFetchSearchResult; import org.opensearch.search.query.QuerySearchResult; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.TestThreadPool; @@ -325,4 +331,172 @@ public void testNotHybridSearchResult_whenResultsNotEmptyAndNotHybridSearchResul verify(normalizationProcessorWorkflow, never()).execute(any(), any(), any(), any()); } + + public void testResultTypes_whenQueryAndFetchPresentAndSizeSame_thenCallNormalization() { + NormalizationProcessorWorkflow normalizationProcessorWorkflow = spy( + new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()) + ); + NormalizationProcessor normalizationProcessor = new NormalizationProcessor( + PROCESSOR_TAG, + DESCRIPTION, + new ScoreNormalizationFactory().createNormalization(NORMALIZATION_METHOD), + new ScoreCombinationFactory().createCombination(COMBINATION_METHOD), + normalizationProcessorWorkflow + ); + + SearchRequest searchRequest = new SearchRequest(INDEX_NAME); + searchRequest.setBatchedReduceSize(4); + AtomicReference onPartialMergeFailure = new AtomicReference<>(); + QueryPhaseResultConsumer queryPhaseResultConsumer = new QueryPhaseResultConsumer( + searchRequest, + executor, + new NoopCircuitBreaker(CircuitBreaker.REQUEST), + searchPhaseController, + SearchProgressListener.NOOP, + writableRegistry(), + 10, + e -> onPartialMergeFailure.accumulateAndGet(e, (prev, curr) -> { + curr.addSuppressed(prev); + return curr; + }) + ); + CountDownLatch partialReduceLatch = new CountDownLatch(5); + int shardId = 0; + SearchShardTarget searchShardTarget = new SearchShardTarget( + "node", + new ShardId("index", "uuid", shardId), + null, + OriginalIndices.NONE + ); + QuerySearchResult querySearchResult = new QuerySearchResult(); + TopDocs topDocs = new TopDocs( + new TotalHits(4, TotalHits.Relation.EQUAL_TO), + + new ScoreDoc[] { + createStartStopElementForHybridSearchResults(4), + createDelimiterElementForHybridSearchResults(4), + new ScoreDoc(0, 0.5f), + new ScoreDoc(2, 0.3f), + new ScoreDoc(4, 0.25f), + new ScoreDoc(10, 0.2f), + createStartStopElementForHybridSearchResults(4) } + + ); + querySearchResult.topDocs(new TopDocsAndMaxScore(topDocs, 0.5f), new DocValueFormat[0]); + querySearchResult.setSearchShardTarget(searchShardTarget); + querySearchResult.setShardIndex(shardId); + + FetchSearchResult fetchSearchResult = new FetchSearchResult(); + fetchSearchResult.setShardIndex(shardId); + fetchSearchResult.setSearchShardTarget(searchShardTarget); + SearchHit[] searchHitArray = new SearchHit[] { + new SearchHit(4, "2", Map.of(), Map.of()), + new SearchHit(4, "2", Map.of(), Map.of()), + new SearchHit(0, "10", Map.of(), Map.of()), + new SearchHit(2, "1", Map.of(), Map.of()), + new SearchHit(4, "2", Map.of(), Map.of()), + new SearchHit(10, "3", Map.of(), Map.of()), + new SearchHit(4, "2", Map.of(), Map.of()) }; + SearchHits searchHits = new SearchHits(searchHitArray, new TotalHits(7, TotalHits.Relation.EQUAL_TO), 10); + fetchSearchResult.hits(searchHits); + + QueryFetchSearchResult queryFetchSearchResult = new QueryFetchSearchResult(querySearchResult, fetchSearchResult); + queryFetchSearchResult.setShardIndex(shardId); + + queryPhaseResultConsumer.consumeResult(queryFetchSearchResult, partialReduceLatch::countDown); + + SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class); + normalizationProcessor.process(queryPhaseResultConsumer, searchPhaseContext); + + List querySearchResults = queryPhaseResultConsumer.getAtomicArray() + .asList() + .stream() + .map(result -> result == null ? null : result.queryResult()) + .collect(Collectors.toList()); + + TestUtils.assertQueryResultScores(querySearchResults); + verify(normalizationProcessorWorkflow).execute(any(), any(), any(), any()); + } + + public void testResultTypes_whenQueryAndFetchPresentButSizeDifferent_thenFail() { + NormalizationProcessorWorkflow normalizationProcessorWorkflow = spy( + new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()) + ); + NormalizationProcessor normalizationProcessor = new NormalizationProcessor( + PROCESSOR_TAG, + DESCRIPTION, + new ScoreNormalizationFactory().createNormalization(NORMALIZATION_METHOD), + new ScoreCombinationFactory().createCombination(COMBINATION_METHOD), + normalizationProcessorWorkflow + ); + + SearchRequest searchRequest = new SearchRequest(INDEX_NAME); + searchRequest.setBatchedReduceSize(4); + AtomicReference onPartialMergeFailure = new AtomicReference<>(); + QueryPhaseResultConsumer queryPhaseResultConsumer = new QueryPhaseResultConsumer( + searchRequest, + executor, + new NoopCircuitBreaker(CircuitBreaker.REQUEST), + searchPhaseController, + SearchProgressListener.NOOP, + writableRegistry(), + 10, + e -> onPartialMergeFailure.accumulateAndGet(e, (prev, curr) -> { + curr.addSuppressed(prev); + return curr; + }) + ); + CountDownLatch partialReduceLatch = new CountDownLatch(5); + int shardId = 0; + SearchShardTarget searchShardTarget = new SearchShardTarget( + "node", + new ShardId("index", "uuid", shardId), + null, + OriginalIndices.NONE + ); + QuerySearchResult querySearchResult = new QuerySearchResult(); + TopDocs topDocs = new TopDocs( + new TotalHits(4, TotalHits.Relation.EQUAL_TO), + + new ScoreDoc[] { + createStartStopElementForHybridSearchResults(4), + createDelimiterElementForHybridSearchResults(4), + new ScoreDoc(0, 0.5f), + new ScoreDoc(2, 0.3f), + new ScoreDoc(4, 0.25f), + new ScoreDoc(10, 0.2f), + createStartStopElementForHybridSearchResults(4) } + + ); + querySearchResult.topDocs(new TopDocsAndMaxScore(topDocs, 0.5f), new DocValueFormat[0]); + querySearchResult.setSearchShardTarget(searchShardTarget); + querySearchResult.setShardIndex(shardId); + + FetchSearchResult fetchSearchResult = new FetchSearchResult(); + fetchSearchResult.setShardIndex(shardId); + fetchSearchResult.setSearchShardTarget(searchShardTarget); + SearchHit[] searchHitArray = new SearchHit[] { + new SearchHit(0, "10", Map.of(), Map.of()), + new SearchHit(2, "1", Map.of(), Map.of()), + new SearchHit(4, "2", Map.of(), Map.of()), + new SearchHit(10, "3", Map.of(), Map.of()), + new SearchHit(0, "10", Map.of(), Map.of()), }; + SearchHits searchHits = new SearchHits(searchHitArray, new TotalHits(5, TotalHits.Relation.EQUAL_TO), 10); + fetchSearchResult.hits(searchHits); + + QueryFetchSearchResult queryFetchSearchResult = new QueryFetchSearchResult(querySearchResult, fetchSearchResult); + queryFetchSearchResult.setShardIndex(shardId); + + queryPhaseResultConsumer.consumeResult(queryFetchSearchResult, partialReduceLatch::countDown); + + SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class); + IllegalStateException exception = expectThrows( + IllegalStateException.class, + () -> normalizationProcessor.process(queryPhaseResultConsumer, searchPhaseContext) + ); + org.hamcrest.MatcherAssert.assertThat( + exception.getMessage(), + startsWith("score normalization processor cannot produce final query result") + ); + } } diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java index 312cb8b3a..864ebdc68 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java @@ -174,6 +174,75 @@ public void testComplexQuery_whenMultipleSubqueries_thenSuccessful() { assertEquals(RELATION_EQUAL_TO, total.get("relation")); } + /** + * Tests complex query with multiple nested sub-queries, where some sub-queries are same + * { + * "query": { + * "hybrid": { + * "queries": [ + * { + * "term": { + * "text": "word1" + * } + * }, + * { + * "term": { + * "text": "word2" + * } + * }, + * { + * "term": { + * "text": "word3" + * } + * } + * ] + * } + * } + * } + */ + @SneakyThrows + public void testComplexQuery_whenMultipleIdenticalSubQueries_thenSuccessful() { + initializeIndexIfNotExist(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME); + + TermQueryBuilder termQueryBuilder1 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); + TermQueryBuilder termQueryBuilder2 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT4); + TermQueryBuilder termQueryBuilder3 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); + + HybridQueryBuilder hybridQueryBuilderThreeTerms = new HybridQueryBuilder(); + hybridQueryBuilderThreeTerms.add(termQueryBuilder1); + hybridQueryBuilderThreeTerms.add(termQueryBuilder2); + hybridQueryBuilderThreeTerms.add(termQueryBuilder3); + + Map searchResponseAsMap1 = search( + TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME, + hybridQueryBuilderThreeTerms, + null, + 10, + Map.of("search_pipeline", SEARCH_PIPELINE) + ); + + assertEquals(2, getHitCount(searchResponseAsMap1)); + + List> hits1NestedList = getNestedHits(searchResponseAsMap1); + List ids = new ArrayList<>(); + List scores = new ArrayList<>(); + for (Map oneHit : hits1NestedList) { + ids.add((String) oneHit.get("_id")); + scores.add((Double) oneHit.get("_score")); + } + + // verify that scores are in desc order + assertTrue(IntStream.range(0, scores.size() - 1).noneMatch(idx -> scores.get(idx) < scores.get(idx + 1))); + // verify that all ids are unique + assertEquals(Set.copyOf(ids).size(), ids.size()); + + Map total = getTotalHits(searchResponseAsMap1); + assertNotNull(total.get("value")); + assertEquals(2, total.get("value")); + assertNotNull(total.get("relation")); + assertEquals(RELATION_EQUAL_TO, total.get("relation")); + } + @SneakyThrows public void testNoMatchResults_whenOnlyTermSubQueryWithoutMatch_thenEmptyResult() { initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME); diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryWeightTests.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryWeightTests.java index 50198af46..0b9af2bcd 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryWeightTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryWeightTests.java @@ -38,7 +38,10 @@ public class HybridQueryWeightTests extends OpenSearchQueryTestCase { - static final String TERM_QUERY_TEXT = "keyword"; + private static final String TERM_QUERY_TEXT = "keyword"; + private static final String RANGE_FIELD = "date _range"; + private static final String FROM_TEXT = "123"; + private static final String TO_TEXT = "456"; @SneakyThrows public void testScorerIterator_whenExecuteQuery_thenScorerIteratorSuccessful() { @@ -87,6 +90,62 @@ public void testScorerIterator_whenExecuteQuery_thenScorerIteratorSuccessful() { directory.close(); } + @SneakyThrows + public void testSubQueries_whenMultipleEqualSubQueries_thenSuccessful() { + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + + Directory directory = newDirectory(); + final IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random()))); + FieldType ft = new FieldType(TextField.TYPE_NOT_STORED); + ft.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS); + ft.setOmitNorms(random().nextBoolean()); + ft.freeze(); + int docId = RandomizedTest.randomInt(); + w.addDocument(getDocument(TEXT_FIELD_NAME, docId, TERM_QUERY_TEXT, ft)); + w.commit(); + + IndexReader reader = DirectoryReader.open(w); + HybridQuery hybridQueryWithTerm = new HybridQuery( + List.of( + QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext), + QueryBuilders.rangeQuery(RANGE_FIELD) + .from(FROM_TEXT) + .to(TO_TEXT) + .rewrite(mockQueryShardContext) + .rewrite(mockQueryShardContext) + .toQuery(mockQueryShardContext), + QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext) + ) + ); + IndexSearcher searcher = newSearcher(reader); + Weight weight = hybridQueryWithTerm.createWeight(searcher, ScoreMode.TOP_SCORES, 1.0f); + + assertNotNull(weight); + + LeafReaderContext leafReaderContext = searcher.getIndexReader().leaves().get(0); + Scorer scorer = weight.scorer(leafReaderContext); + + assertNotNull(scorer); + + DocIdSetIterator iterator = scorer.iterator(); + int actualDoc = iterator.nextDoc(); + int actualDocId = Integer.parseInt(reader.document(actualDoc).getField("id").stringValue()); + + assertEquals(docId, actualDocId); + + assertTrue(weight.isCacheable(leafReaderContext)); + + Matches matches = weight.matches(leafReaderContext, actualDoc); + MatchesIterator matchesIterator = matches.getMatches(TEXT_FIELD_NAME); + assertTrue(matchesIterator.next()); + + w.close(); + reader.close(); + directory.close(); + } + @SneakyThrows public void testExplain_whenCallExplain_thenFail() { QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); diff --git a/src/test/java/org/opensearch/neuralsearch/search/HitsTresholdCheckerTests.java b/src/test/java/org/opensearch/neuralsearch/search/HitsThresholdCheckerTests.java similarity index 53% rename from src/test/java/org/opensearch/neuralsearch/search/HitsTresholdCheckerTests.java rename to src/test/java/org/opensearch/neuralsearch/search/HitsThresholdCheckerTests.java index 0a6a12c88..3ce9e3dfe 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/HitsTresholdCheckerTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/HitsThresholdCheckerTests.java @@ -10,9 +10,9 @@ import org.apache.lucene.search.ScoreMode; import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; -public class HitsTresholdCheckerTests extends OpenSearchQueryTestCase { +public class HitsThresholdCheckerTests extends OpenSearchQueryTestCase { - public void testTresholdReached_whenIncrementCount_thenTresholdReached() { + public void testThresholdReached_whenIncrementCount_thenThresholdReached() { HitsThresholdChecker hitsThresholdChecker = new HitsThresholdChecker(5); assertEquals(5, hitsThresholdChecker.getTotalHitsThreshold()); assertEquals(ScoreMode.TOP_SCORES, hitsThresholdChecker.scoreMode()); @@ -23,11 +23,17 @@ public void testTresholdReached_whenIncrementCount_thenTresholdReached() { assertTrue(hitsThresholdChecker.isThresholdReached()); } - public void testTresholdLimit_whenThresholdNegative_thenFail() { + public void testThresholdLimit_whenThresholdNegative_thenFail() { expectThrows(IllegalArgumentException.class, () -> new HitsThresholdChecker(-1)); } - public void testTresholdLimit_whenThresholdMaxValue_thenFail() { - expectThrows(IllegalArgumentException.class, () -> new HitsThresholdChecker(Integer.MAX_VALUE)); + public void testTrackThreshold_whenTrackThresholdSet_thenSuccessful() { + HitsThresholdChecker hitsThresholdChecker = new HitsThresholdChecker(Integer.MAX_VALUE); + assertEquals(ScoreMode.TOP_SCORES, hitsThresholdChecker.scoreMode()); + assertFalse(hitsThresholdChecker.isThresholdReached()); + hitsThresholdChecker.incrementHitCount(); + assertFalse(hitsThresholdChecker.isThresholdReached()); + IntStream.rangeClosed(1, 5).forEach((checker) -> hitsThresholdChecker.incrementHitCount()); + assertFalse(hitsThresholdChecker.isThresholdReached()); } } diff --git a/src/test/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollectorTests.java b/src/test/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollectorTests.java index 06bbfc416..72cf8be49 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollectorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollectorTests.java @@ -349,4 +349,55 @@ public void testTopDocs_whenMatchedDocsDifferentForEachSubQuery_thenSuccessful() reader.close(); directory.close(); } + + @SneakyThrows + public void testTrackTotalHits_whenTotalHitsSetIntegerMaxValue_thenSuccessful() { + final Directory directory = newDirectory(); + final IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random()))); + FieldType ft = new FieldType(TextField.TYPE_NOT_STORED); + ft.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS); + ft.setOmitNorms(random().nextBoolean()); + ft.freeze(); + + w.addDocument(getDocument(TEXT_FIELD_NAME, DOC_ID_1, FIELD_1_VALUE, ft)); + w.addDocument(getDocument(TEXT_FIELD_NAME, DOC_ID_2, FIELD_2_VALUE, ft)); + w.addDocument(getDocument(TEXT_FIELD_NAME, DOC_ID_3, FIELD_3_VALUE, ft)); + w.commit(); + + DirectoryReader reader = DirectoryReader.open(w); + + LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0); + + HybridTopScoreDocCollector hybridTopScoreDocCollector = new HybridTopScoreDocCollector( + NUM_DOCS, + new HitsThresholdChecker(Integer.MAX_VALUE) + ); + LeafCollector leafCollector = hybridTopScoreDocCollector.getLeafCollector(leafReaderContext); + assertNotNull(leafCollector); + + Weight weight = mock(Weight.class); + int[] docIds = new int[] { DOC_ID_1, DOC_ID_2, DOC_ID_3 }; + Arrays.sort(docIds); + final List scores = Stream.generate(() -> random().nextFloat()).limit(NUM_DOCS).collect(Collectors.toList()); + HybridQueryScorer hybridQueryScorer = new HybridQueryScorer( + weight, + Arrays.asList(scorer(docIds, scores, fakeWeight(new MatchAllDocsQuery()))) + ); + + leafCollector.setScorer(hybridQueryScorer); + List hybridScores = new ArrayList<>(); + DocIdSetIterator iterator = hybridQueryScorer.iterator(); + int nextDoc = iterator.nextDoc(); + while (nextDoc != NO_MORE_DOCS) { + hybridScores.add(hybridQueryScorer.hybridScores()); + nextDoc = iterator.nextDoc(); + } + // assert + assertEquals(3, hybridScores.size()); + assertFalse(hybridScores.stream().anyMatch(score -> score[0] <= 0.0)); + + w.close(); + reader.close(); + directory.close(); + } }