Skip to content

Commit

Permalink
Throw exception when results of fetch and query phases are different
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski committed Dec 29, 2023
1 parent 6801844 commit 7d6dc4c
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 72 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@

import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.isHybridQueryStartStopElement;

import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;
Expand All @@ -20,8 +18,6 @@
import org.opensearch.action.search.SearchPhaseResults;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
import org.opensearch.search.SearchHit;
import org.opensearch.search.SearchHits;
import org.opensearch.search.SearchPhaseResult;
import org.opensearch.search.fetch.FetchSearchResult;
import org.opensearch.search.internal.SearchContext;
Expand Down Expand Up @@ -102,16 +98,7 @@ private <Result extends SearchPhaseResult> boolean shouldSkipProcessor(SearchPha
}

QueryPhaseResultConsumer queryPhaseResultConsumer = (QueryPhaseResultConsumer) searchPhaseResult;
if (queryPhaseResultConsumer.getAtomicArray().asList().stream().filter(Objects::nonNull).noneMatch(this::isHybridQuery)) {
return true;
}
List<QuerySearchResult> querySearchResults = getQueryPhaseSearchResults(searchPhaseResult);
Optional<FetchSearchResult> fetchSearchResult = getFetchSearchResults(searchPhaseResult);
if (shouldSkipProcessorDueToIncompatibleQueryAndFetchResults(querySearchResults, fetchSearchResult)) {
log.debug("Query and fetch results do not match, normalization processor is skipped");
return true;
}
return false;
return queryPhaseResultConsumer.getAtomicArray().asList().stream().filter(Objects::nonNull).noneMatch(this::isHybridQuery);
}

/**
Expand Down Expand Up @@ -144,41 +131,4 @@ private <Result extends SearchPhaseResult> Optional<FetchSearchResult> getFetchS
Optional<Result> optionalFirstSearchPhaseResult = searchPhaseResults.getAtomicArray().asList().stream().findFirst();
return optionalFirstSearchPhaseResult.map(SearchPhaseResult::fetchResult);
}

/**
 * Checks whether the query-phase and fetch-phase results are compatible for normalization.
 * Normalization rewrites fetch hits in place, so both phases must refer to the same set of
 * documents; a size mismatch (or missing hit array) means the processor must be skipped.
 *
 * @param querySearchResults per-shard query phase results
 * @param fetchSearchResultOptional fetch phase result for the same request, if it ran
 * @return {@code true} when the fetch result is present but does not line up with the
 *         query-phase doc ids, {@code false} otherwise (including when fetch is absent)
 */
private boolean shouldSkipProcessorDueToIncompatibleQueryAndFetchResults(
    final List<QuerySearchResult> querySearchResults,
    final Optional<FetchSearchResult> fetchSearchResultOptional
) {
    // No fetch phase result -> nothing to compare against, normalization can proceed.
    if (fetchSearchResultOptional.isEmpty()) {
        return false;
    }
    final List<Integer> docIds = unprocessedDocIds(querySearchResults);
    SearchHits searchHits = fetchSearchResultOptional.get().hits();
    SearchHit[] searchHitArray = searchHits.getHits();
    // validate that both collections are of the same size
    if (Objects.isNull(searchHitArray)) {
        log.info("array of search hits in fetch phase results is null");
        return true;
    }
    if (searchHitArray.length != docIds.size()) {
        // Parameterized message: log4j2 formats lazily, avoiding String.format work
        // when the INFO level is disabled.
        log.info(
            "number of documents in fetch results [{}] and query results [{}] is different",
            searchHitArray.length,
            docIds.size()
        );
        return true;
    }
    return false;
}

/**
 * Collects the doc ids produced by the query phase, taken from the first shard result.
 *
 * @param querySearchResults per-shard query phase results
 * @return doc ids of the first shard's top docs, or an empty list when there are no results
 */
private List<Integer> unprocessedDocIds(final List<QuerySearchResult> querySearchResults) {
    if (querySearchResults.isEmpty()) {
        return List.of();
    }
    return Arrays.stream(querySearchResults.get(0).topDocs().topDocs.scoreDocs)
        .map(topDoc -> topDoc.doc)
        .collect(Collectors.toList());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ private void updateOriginalFetchResults(
// 3. update original scores to normalized and combined values
// 4. order scores based on normalized and combined values
FetchSearchResult fetchSearchResult = fetchSearchResultOptional.get();
SearchHit[] searchHitArray = getSearchHits(fetchSearchResult);
SearchHit[] searchHitArray = getSearchHits(docIds, fetchSearchResult);

// create map of docId to index of search hits. This solves (2), duplicates are from
// delimiter and start/stop elements, they all have same valid doc_id. For this map
Expand Down Expand Up @@ -169,9 +169,21 @@ private void updateOriginalFetchResults(
fetchSearchResult.hits(updatedSearchHits);
}

private SearchHit[] getSearchHits(final FetchSearchResult fetchSearchResult) {
private SearchHit[] getSearchHits(final List<Integer> docIds, final FetchSearchResult fetchSearchResult) {
SearchHits searchHits = fetchSearchResult.hits();
return searchHits.getHits();
SearchHit[] searchHitArray = searchHits.getHits();
// validate the both collections are of the same size
if (Objects.isNull(searchHitArray)) {
throw new IllegalStateException(

Check warning on line 177 in src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java#L177

Added line #L177 was not covered by tests
"score normalization processor cannot produce final query result, fetch query phase returns empty results"
);
}
if (searchHitArray.length != docIds.size()) {
throw new IllegalStateException(
"score normalization processor cannot produce final query result, the number of documents returned by fetch and query phases does not match"
);
}
return searchHitArray;
}

private List<Integer> unprocessedDocIds(final List<QuerySearchResult> querySearchResults) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.neuralsearch.processor;

import static org.hamcrest.Matchers.startsWith;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
Expand Down Expand Up @@ -417,7 +418,7 @@ public void testResultTypes_whenQueryAndFetchPresentAndSizeSame_thenCallNormaliz
verify(normalizationProcessorWorkflow).execute(any(), any(), any(), any());
}

public void testResultTypes_whenQueryAndFetchPresentButSizeDifferent_thenSkipNormalization() {
public void testResultTypes_whenQueryAndFetchPresentButSizeDifferent_thenFail() {
NormalizationProcessorWorkflow normalizationProcessorWorkflow = spy(
new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner())
);
Expand Down Expand Up @@ -489,15 +490,13 @@ public void testResultTypes_whenQueryAndFetchPresentButSizeDifferent_thenSkipNor
queryPhaseResultConsumer.consumeResult(queryFetchSearchResult, partialReduceLatch::countDown);

SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class);
normalizationProcessor.process(queryPhaseResultConsumer, searchPhaseContext);

List<QuerySearchResult> querySearchResults = queryPhaseResultConsumer.getAtomicArray()
.asList()
.stream()
.map(result -> result == null ? null : result.queryResult())
.collect(Collectors.toList());

assertNotNull(querySearchResults);
verify(normalizationProcessorWorkflow, never()).execute(any(), any(), any(), any());
IllegalStateException exception = expectThrows(
IllegalStateException.class,
() -> normalizationProcessor.process(queryPhaseResultConsumer, searchPhaseContext)
);
org.hamcrest.MatcherAssert.assertThat(
exception.getMessage(),
startsWith("score normalization processor cannot produce final query result")
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ public void testFetchResults_whenOneShardAndMultipleNodes_thenDoNormalizationCom
TestUtils.assertFetchResultScores(fetchSearchResult, 4);
}

public void testFetchResults_whenOneShardAndMultipleNodesAndMismatchResults_thenSuccess() {
public void testFetchResults_whenOneShardAndMultipleNodesAndMismatchResults_thenFail() {
NormalizationProcessorWorkflow normalizationProcessorWorkflow = spy(
new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner())
);
Expand Down Expand Up @@ -282,12 +282,14 @@ public void testFetchResults_whenOneShardAndMultipleNodesAndMismatchResults_then
SearchHits searchHits = new SearchHits(searchHitArray, new TotalHits(7, TotalHits.Relation.EQUAL_TO), 10);
fetchSearchResult.hits(searchHits);

normalizationProcessorWorkflow.execute(
querySearchResults,
Optional.of(fetchSearchResult),
ScoreNormalizationFactory.DEFAULT_METHOD,
ScoreCombinationFactory.DEFAULT_METHOD
expectThrows(
IllegalStateException.class,
() -> normalizationProcessorWorkflow.execute(
querySearchResults,
Optional.of(fetchSearchResult),
ScoreNormalizationFactory.DEFAULT_METHOD,
ScoreCombinationFactory.DEFAULT_METHOD
)
);
TestUtils.assertQueryResultScores(querySearchResults);
}
}

0 comments on commit 7d6dc4c

Please sign in to comment.