
Commit

Integ Test and unit test

Signed-off-by: Varun Jain <[email protected]>
vibrantvarun committed Oct 23, 2024
1 parent f1c64dc commit 152aa4c
Showing 14 changed files with 530 additions and 85 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),

## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.17...2.x)
### Features
- Pagination in Hybrid query ([]())
### Enhancements
### Bug Fixes
### Infrastructure
NormalizationProcessor.java
@@ -58,7 +58,21 @@ public <Result extends SearchPhaseResult> void process(
}
List<QuerySearchResult> querySearchResults = getQueryPhaseSearchResults(searchPhaseResult);
Optional<FetchSearchResult> fetchSearchResult = getFetchSearchResults(searchPhaseResult);
normalizationWorkflow.execute(querySearchResults, fetchSearchResult, normalizationTechnique, combinationTechnique);
int fromValueForSingleShard = 0;
boolean isSingleShard = false;
if (searchPhaseContext.getNumShards() == 1 && fetchSearchResult.isPresent()) {
isSingleShard = true;
fromValueForSingleShard = searchPhaseContext.getRequest().source().from();
}

normalizationWorkflow.execute(
querySearchResults,
fetchSearchResult,
normalizationTechnique,
combinationTechnique,
fromValueForSingleShard,
isSingleShard
);
}

@Override
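The process() change above captures the request-level from offset only when the search runs on a single shard and fetch results are already present; with multiple shards the coordinating node applies the offset while reducing results, so nothing has to be remembered. A minimal standalone sketch of that decision (hypothetical helper, names not from this commit):

    static int captureFromOffsetForSingleShard(int numShards, boolean fetchResultPresent, int requestedFrom) {
        // with one shard the fetch phase has already produced the final hits, so the
        // processor itself must re-apply the from offset after scores are combined
        if (numShards == 1 && fetchResultPresent) {
            return requestedFrom;
        }
        // multi-shard results are sliced later by the coordinator, so no offset is kept
        return 0;
    }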
NormalizationProcessorWorkflow.java
@@ -55,7 +55,9 @@ public void execute(
final List<QuerySearchResult> querySearchResults,
final Optional<FetchSearchResult> fetchSearchResultOptional,
final ScoreNormalizationTechnique normalizationTechnique,
final ScoreCombinationTechnique combinationTechnique
final ScoreCombinationTechnique combinationTechnique,
final int fromValueForSingleShard,
final boolean isSingleShard
) {
// save original state
List<Integer> unprocessedDocIds = unprocessedDocIds(querySearchResults);
@@ -73,6 +75,8 @@
.scoreCombinationTechnique(combinationTechnique)
.querySearchResults(querySearchResults)
.sort(evaluateSortCriteria(querySearchResults, queryTopDocs))
.fromValueForSingleShard(fromValueForSingleShard)
.isSingleShard(isSingleShard)
.build();

// combine
@@ -82,7 +86,7 @@
// post-process data
log.debug("Post-process query results after score normalization and combination");
updateOriginalQueryResults(combineScoresDTO);
updateOriginalFetchResults(querySearchResults, fetchSearchResultOptional, unprocessedDocIds);
updateOriginalFetchResults(querySearchResults, fetchSearchResultOptional, unprocessedDocIds, fromValueForSingleShard);
}

/**
@@ -123,10 +127,14 @@ private void updateOriginalQueryResults(final CombineScoresDto combineScoresDTO)
buildTopDocs(updatedTopDocs, sort),
maxScoreForShard(updatedTopDocs, sort != null)
);
if (combineScoresDTO.isSingleShard()) {
querySearchResult.from(combineScoresDTO.getFromValueForSingleShard());
}
querySearchResult.topDocs(updatedTopDocsAndMaxScore, querySearchResult.sortValueFormats());
}

if (from > 0 && from > totalScoreDocsCount) {
if ((from > 0 || combineScoresDTO.getFromValueForSingleShard() > 0)
&& (from > totalScoreDocsCount || combineScoresDTO.getFromValueForSingleShard() > totalScoreDocsCount)) {
throw new IllegalArgumentException(
String.format(Locale.ROOT, "Reached end of search result, increase pagination_depth value to see more results")
);
@@ -189,7 +197,8 @@ private TopDocs buildTopDocs(CompoundTopDocs updatedTopDocs, Sort sort) {
private void updateOriginalFetchResults(
final List<QuerySearchResult> querySearchResults,
final Optional<FetchSearchResult> fetchSearchResultOptional,
final List<Integer> docIds
final List<Integer> docIds,
final int fromValueForSingleShard
) {
if (fetchSearchResultOptional.isEmpty()) {
return;
@@ -221,14 +230,26 @@ private void updateOriginalFetchResults(

QuerySearchResult querySearchResult = querySearchResults.get(0);
TopDocs topDocs = querySearchResult.topDocs().topDocs;

// iterate over the normalized/combined scores, that solves (1) and (3)
SearchHit[] updatedSearchHitArray = Arrays.stream(topDocs.scoreDocs).map(scoreDoc -> {
SearchHit[] updatedSearchHitArray = new SearchHit[topDocs.scoreDocs.length - fromValueForSingleShard];
for (int i = fromValueForSingleShard; i < topDocs.scoreDocs.length; i++) {
ScoreDoc scoreDoc = topDocs.scoreDocs[i];
// get fetched hit content by doc_id
SearchHit searchHit = docIdToSearchHit.get(scoreDoc.doc);
// update score to normalized/combined value (3)
searchHit.score(scoreDoc.score);
return searchHit;
}).toArray(SearchHit[]::new);
updatedSearchHitArray[i - fromValueForSingleShard] = searchHit;
}

SearchHits updatedSearchHits = new SearchHits(
updatedSearchHitArray,
querySearchResult.getTotalHits(),
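The new loop in updateOriginalFetchResults drops the first fromValueForSingleShard hits and shifts the rest left while keeping combined-score order, which is what makes pagination work on a single shard. The same slicing reduced to plain arrays (illustrative sketch only, not code from this commit):

    static String[] sliceHits(String[] rankedHits, int from) {
        String[] page = new String[rankedHits.length - from];
        for (int i = from; i < rankedHits.length; i++) {
            page[i - from] = rankedHits[i]; // shift left by from, preserving rank order
        }
        return page;
    }

    // sliceHits(new String[] {"a", "b", "c", "d"}, 2) returns ["c", "d"]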
CombineScoresDto.java
@@ -29,4 +29,6 @@ public class CombineScoresDto {
private List<QuerySearchResult> querySearchResults;
@Nullable
private Sort sort;
private int fromValueForSingleShard;
private boolean isSingleShard;
}
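The workflow above populates these fields through a Lombok-style builder (.fromValueForSingleShard(...), .isSingleShard(...)), so the two additions travel with the DTO into updateOriginalQueryResults. A usage sketch with placeholder values (builder methods assumed to be generated from the field names):

    CombineScoresDto dto = CombineScoresDto.builder()
        .scoreCombinationTechnique(combinationTechnique) // placeholder variables
        .querySearchResults(querySearchResults)
        .fromValueForSingleShard(5) // original request from, kept for local slicing
        .isSingleShard(true)
        .build();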
HybridCollectorManager.java
@@ -6,7 +6,9 @@

import java.util.Locale;
import lombok.RequiredArgsConstructor;
import lombok.extern.log4j.Log4j2;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.Collector;
import org.apache.lucene.search.CollectorManager;
import org.apache.lucene.search.Weight;
@@ -56,6 +58,7 @@
* In most cases it will be wrapped in MultiCollectorManager.
*/
@RequiredArgsConstructor
@Log4j2
public abstract class HybridCollectorManager implements CollectorManager<Collector, ReduceableSearchResult> {

private final int numHits;
@@ -68,6 +71,7 @@ public abstract class HybridCollectorManager implements CollectorManager<Collector, ReduceableSearchResult> {
private final TopDocsMerger topDocsMerger;
@Nullable
private final FieldDoc after;
private static final int DEFAULT_PAGINATION_DEPTH = 10;

/**
* Create new instance of HybridCollectorManager depending on the concurrent search being enabled or disabled.
@@ -76,20 +80,20 @@ public abstract class HybridCollectorManager implements CollectorManager<Collector, ReduceableSearchResult> {
* @throws IOException
*/
public static CollectorManager createHybridCollectorManager(final SearchContext searchContext) throws IOException {
if (searchContext.scrollContext() != null) {
throw new IllegalArgumentException(String.format(Locale.ROOT, "Scroll operation is not supported in hybrid query"));
}
final IndexReader reader = searchContext.searcher().getIndexReader();
final int totalNumDocs = Math.max(0, reader.numDocs());
HybridQuery hybridQuery = (HybridQuery) searchContext.query();
int retrievalSize;
if (hybridQuery.getPaginationDepth() == 0) {
retrievalSize = searchContext.from() + searchContext.size();
} else {
retrievalSize = hybridQuery.getPaginationDepth();
}
int numDocs = Math.min(retrievalSize, totalNumDocs);
int numDocs = Math.min(getSubqueryResultsRetrievalSize(searchContext), totalNumDocs);
int trackTotalHitsUpTo = searchContext.trackTotalHitsUpTo();
if (searchContext.sort() != null) {
validateSortCriteria(searchContext, searchContext.trackScores());
}
boolean isSingleShard = searchContext.numberOfShards() == 1;
if (isSingleShard && searchContext.from() > 0) {
searchContext.from(0);
}
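// NOTE (inference from the NormalizationProcessor change in this commit): from is
// reset to 0 here so a single shard collects and fetches hits starting at rank 0;
// the original offset was already captured from the request source and is re-applied
// after score normalization/combination in NormalizationProcessorWorkflow.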

Weight filteringWeight = null;
// Check for post filter to create weight for filter query and later use that weight in the search workflow
@@ -412,6 +416,42 @@ private ReduceableSearchResult reduceSearchResults(final List<ReduceableSearchResult>
};
}

/**
* Get maximum subquery results count to be collected from each shard.
* @param searchContext search context that contains pagination depth
* @return results size to be collected
*/
private static int getSubqueryResultsRetrievalSize(final SearchContext searchContext) {
int paginationDepth;
HybridQuery hybridQuery;
Query query = searchContext.query();
if (query instanceof BooleanQuery) {
BooleanQuery booleanQuery = (BooleanQuery) query;
hybridQuery = (HybridQuery) booleanQuery.clauses().get(0).getQuery();
paginationDepth = hybridQuery.getPaginationDepth();
} else {
hybridQuery = (HybridQuery) query;
paginationDepth = hybridQuery.getPaginationDepth();
}

if (paginationDepth != 0) {
validatePaginationDepth(paginationDepth);
return paginationDepth;
} else if (searchContext.from() > 0 && paginationDepth == 0) {
return DEFAULT_PAGINATION_DEPTH;
} else {
return searchContext.from() + searchContext.size();
}
}

private static void validatePaginationDepth(int depth) {
if (depth < 0 || depth > 10000) {
throw new IllegalArgumentException(
String.format(Locale.ROOT, "Pagination depth should lie in the range of 1-10000. Received: %s", depth)
);
}
}

/**
* Implementation of the HybridCollector that reuses an instance of the collector on every call. This allows the caller to
* use the saved state of the collector
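getSubqueryResultsRetrievalSize reads pagination_depth from the HybridQuery, unwrapping clause 0 when a filter has wrapped the query in a BooleanQuery, and falls back in two steps when no depth is set. A standalone restatement of that resolution order with sample values (sketch only; validation omitted, constant taken from the diff above):

    static int resolveRetrievalSize(int paginationDepth, int from, int size) {
        if (paginationDepth != 0) {
            return paginationDepth; // an explicit pagination_depth always wins
        }
        if (from > 0) {
            return 10; // DEFAULT_PAGINATION_DEPTH: paging without an explicit depth
        }
        return from + size; // first page: collect just enough hits per subquery
    }

    // resolveRetrievalSize(20, 5, 10) -> 20
    // resolveRetrievalSize(0, 5, 10)  -> 10
    // resolveRetrievalSize(0, 0, 25)  -> 25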
HybridQueryPhaseSearcher.java
@@ -60,10 +60,6 @@ public boolean searchWith(
validateQuery(searchContext, query);
return super.searchWith(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout);
} else {
// TODO remove this check after following issue https://github.com/opensearch-project/neural-search/issues/280 gets resolved.
// if (searchContext.from() != 0) {
// throw new IllegalArgumentException("In the current OpenSearch version pagination is not supported with hybrid query");
// }
Query hybridQuery = extractHybridQuery(searchContext, query);
QueryPhaseSearcher queryPhaseSearcher = getQueryPhaseSearcher(searchContext);
return queryPhaseSearcher.searchWith(searchContext, searcher, hybridQuery, collectors, hasFilterCollector, hasTimeout);
@@ -214,6 +210,5 @@ protected boolean searchWithCollector(
hasTimeout
);
}

}
}
NormalizationProcessorTests.java
@@ -6,6 +6,8 @@

import static org.hamcrest.Matchers.startsWith;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.spy;
@@ -272,7 +274,7 @@ public void testEmptySearchResults_whenEmptySearchResults_thenDoNotExecuteWorkflow() {
SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class);
normalizationProcessor.process(null, searchPhaseContext);

verify(normalizationProcessorWorkflow, never()).execute(any(), any(), any(), any());
verify(normalizationProcessorWorkflow, never()).execute(any(), any(), any(), any(), anyInt(), anyBoolean());
}

public void testNotHybridSearchResult_whenResultsNotEmptyAndNotHybridSearchResult_thenDoNotExecuteWorkflow() {
@@ -328,7 +330,7 @@ public void testNotHybridSearchResult_whenResultsNotEmptyAndNotHybridSearchResult_thenDoNotExecuteWorkflow() {
when(searchPhaseContext.getNumShards()).thenReturn(numberOfShards);
normalizationProcessor.process(queryPhaseResultConsumer, searchPhaseContext);

verify(normalizationProcessorWorkflow, never()).execute(any(), any(), any(), any());
verify(normalizationProcessorWorkflow, never()).execute(any(), any(), any(), any(), anyInt(), anyBoolean());
}

public void testResultTypes_whenQueryAndFetchPresentAndSizeSame_thenCallNormalization() {
@@ -417,7 +419,7 @@ public void testResultTypes_whenQueryAndFetchPresentAndSizeSame_thenCallNormalization() {
.collect(Collectors.toList());

TestUtils.assertQueryResultScores(querySearchResults);
verify(normalizationProcessorWorkflow).execute(any(), any(), any(), any());
verify(normalizationProcessorWorkflow).execute(any(), any(), any(), any(), anyInt(), anyBoolean());
}

public void testResultTypes_whenQueryAndFetchPresentButSizeDifferent_thenFail() {
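The added anyInt()/anyBoolean() matchers mirror the widened execute() signature; once any Mockito matcher is used, every argument needs one. A sketch of the matching stub for the void workflow method (assumed setup, not part of this commit):

    doNothing().when(normalizationProcessorWorkflow)
        .execute(any(), any(), any(), any(), anyInt(), anyBoolean());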
