Skip to content

Commit

Permalink
Add check for fetch and query result sizes
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski committed Dec 28, 2023
1 parent 71dc465 commit e404775
Show file tree
Hide file tree
Showing 3 changed files with 222 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.isHybridQueryStartStopElement;

import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
Expand All @@ -18,6 +19,8 @@
import org.opensearch.action.search.SearchPhaseResults;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
import org.opensearch.search.SearchHit;
import org.opensearch.search.SearchHits;
import org.opensearch.search.SearchPhaseResult;
import org.opensearch.search.fetch.FetchSearchResult;
import org.opensearch.search.internal.SearchContext;
Expand Down Expand Up @@ -98,7 +101,16 @@ private <Result extends SearchPhaseResult> boolean shouldSkipProcessor(SearchPha
}

QueryPhaseResultConsumer queryPhaseResultConsumer = (QueryPhaseResultConsumer) searchPhaseResult;
return queryPhaseResultConsumer.getAtomicArray().asList().stream().filter(Objects::nonNull).noneMatch(this::isHybridQuery);
if (queryPhaseResultConsumer.getAtomicArray().asList().stream().filter(Objects::nonNull).noneMatch(this::isHybridQuery)) {
return true;
}
List<QuerySearchResult> querySearchResults = getQueryPhaseSearchResults(searchPhaseResult);
Optional<FetchSearchResult> fetchSearchResult = getFetchSearchResults(searchPhaseResult);
if (shouldSkipProcessorDueToIncompatibleQueryAndFetchResults(querySearchResults, fetchSearchResult)) {
log.debug("Query and fetch results do not match, normalization processor is skipped");
return true;
}
return false;
}

/**
Expand Down Expand Up @@ -131,4 +143,29 @@ private <Result extends SearchPhaseResult> Optional<FetchSearchResult> getFetchS
Optional<Result> optionalFirstSearchPhaseResult = searchPhaseResults.getAtomicArray().asList().stream().findFirst();
return optionalFirstSearchPhaseResult.map(SearchPhaseResult::fetchResult);
}

/**
 * Decides whether the processor has to be skipped because the fetch result cannot be
 * lined up with the query result.
 *
 * @param querySearchResults query phase results; their doc ids are obtained via {@code unprocessedDocIds}
 * @param fetchSearchResultOptional fetch phase result, may be empty
 * @return {@code false} when no fetch result is present; otherwise {@code true} if the fetched hit
 *         array is null or its length differs from the number of doc ids, {@code false} if they match
 */
private boolean shouldSkipProcessorDueToIncompatibleQueryAndFetchResults(
    final List<QuerySearchResult> querySearchResults,
    final Optional<FetchSearchResult> fetchSearchResultOptional
) {
    // nothing to compare against when there is no fetch result
    if (fetchSearchResultOptional.isEmpty()) {
        return false;
    }
    final int expectedHitCount = unprocessedDocIds(querySearchResults).size();
    final SearchHit[] fetchedHits = fetchSearchResultOptional.get().hits().getHits();
    // incompatible when there are no fetched hits at all or the two sizes disagree
    return Objects.isNull(fetchedHits) || fetchedHits.length != expectedHitCount;
}

/**
 * Collects the doc ids of the score docs held by the first query search result.
 * Only the element at index 0 of the list is read; every score doc in its top docs
 * contributes its {@code doc} value, in array order.
 *
 * @param querySearchResults query phase results
 * @return an empty list when {@code querySearchResults} is empty, otherwise the doc ids
 *         of all score docs of the first result
 */
private List<Integer> unprocessedDocIds(final List<QuerySearchResult> querySearchResults) {
return querySearchResults.isEmpty()
? List.of()

Check warning on line 166 in src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java#L166

Added line #L166 was not covered by tests
: Arrays.stream(querySearchResults.get(0).topDocs().topDocs.scoreDocs)
.map(scoreDoc -> scoreDoc.doc)
.collect(Collectors.toList());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,15 @@ private SearchHit[] getSearchHits(final List<Integer> docIds, final FetchSearchR
SearchHits searchHits = fetchSearchResult.hits();
SearchHit[] searchHitArray = searchHits.getHits();
// validate that both collections are of the same size
if (Objects.isNull(searchHitArray) || searchHitArray.length != docIds.size()) {
throw new IllegalStateException("Score normalization processor cannot produce final query result");
if (Objects.isNull(searchHitArray)) {
throw new IllegalStateException(

Check warning on line 177 in src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java#L177

Added line #L177 was not covered by tests
"Score normalization processor cannot produce final query result, for one shard case fetch does not have any results"
);
}
if (searchHitArray.length != docIds.size()) {
throw new IllegalStateException(
"Score normalization processor cannot produce final query result, for one shard case number of fetched documents does not match number of search hits"
);
}
return searchHitArray;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createStartStopElementForHybridSearchResults;

import java.util.List;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -45,9 +46,13 @@
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationFactory;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer;
import org.opensearch.search.DocValueFormat;
import org.opensearch.search.SearchHit;
import org.opensearch.search.SearchHits;
import org.opensearch.search.SearchShardTarget;
import org.opensearch.search.aggregations.InternalAggregation;
import org.opensearch.search.aggregations.pipeline.PipelineAggregator;
import org.opensearch.search.fetch.FetchSearchResult;
import org.opensearch.search.fetch.QueryFetchSearchResult;
import org.opensearch.search.query.QuerySearchResult;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.TestThreadPool;
Expand Down Expand Up @@ -325,4 +330,174 @@ public void testNotHybridSearchResult_whenResultsNotEmptyAndNotHybridSearchResul

verify(normalizationProcessorWorkflow, never()).execute(any(), any(), any(), any());
}

/**
 * Verifies that the normalization workflow is executed for a combined query-and-fetch result
 * when the fetch result holds as many hits as the query result holds score docs.
 *
 * Setup: a single shard (index 0) whose query result carries seven score docs (the elements
 * produced by the start/stop and delimiter helpers plus four regular docs) and whose fetch
 * result carries seven search hits. Both are wrapped into one {@code QueryFetchSearchResult}
 * and handed to the processor through a {@code QueryPhaseResultConsumer}.
 */
public void testResultTypes_whenQueryAndFetchPresentAndSizeSame_thenCallNormalization() {
    NormalizationProcessorWorkflow normalizationProcessorWorkflow = spy(
        new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner())
    );
    NormalizationProcessor normalizationProcessor = new NormalizationProcessor(
        PROCESSOR_TAG,
        DESCRIPTION,
        new ScoreNormalizationFactory().createNormalization(NORMALIZATION_METHOD),
        new ScoreCombinationFactory().createCombination(COMBINATION_METHOD),
        normalizationProcessorWorkflow
    );

    SearchRequest searchRequest = new SearchRequest(INDEX_NAME);
    searchRequest.setBatchedReduceSize(4);
    AtomicReference<Exception> onPartialMergeFailure = new AtomicReference<>();
    QueryPhaseResultConsumer queryPhaseResultConsumer = new QueryPhaseResultConsumer(
        searchRequest,
        executor,
        new NoopCircuitBreaker(CircuitBreaker.REQUEST),
        searchPhaseController,
        SearchProgressListener.NOOP,
        writableRegistry(),
        10,
        e -> onPartialMergeFailure.accumulateAndGet(e, (prev, curr) -> {
            // The reference starts out empty, so the first failure has no predecessor;
            // Throwable#addSuppressed throws NullPointerException when given null.
            if (prev != null) {
                curr.addSuppressed(prev);
            }
            return curr;
        })
    );
    CountDownLatch partialReduceLatch = new CountDownLatch(5);
    int shardId = 0;
    SearchShardTarget searchShardTarget = new SearchShardTarget(
        "node",
        new ShardId("index", "uuid", shardId),
        null,
        OriginalIndices.NONE
    );

    // query side: seven score docs in total
    QuerySearchResult querySearchResult = new QuerySearchResult();
    TopDocs topDocs = new TopDocs(
        new TotalHits(4, TotalHits.Relation.EQUAL_TO),

        new ScoreDoc[] {
            createStartStopElementForHybridSearchResults(4),
            createDelimiterElementForHybridSearchResults(4),
            new ScoreDoc(0, 0.5f),
            new ScoreDoc(2, 0.3f),
            new ScoreDoc(4, 0.25f),
            new ScoreDoc(10, 0.2f),
            createStartStopElementForHybridSearchResults(4) }

    );
    querySearchResult.topDocs(new TopDocsAndMaxScore(topDocs, 0.5f), new DocValueFormat[0]);
    querySearchResult.setSearchShardTarget(searchShardTarget);
    querySearchResult.setShardIndex(shardId);

    // fetch side: seven hits, i.e. the same count as the score docs above
    FetchSearchResult fetchSearchResult = new FetchSearchResult();
    fetchSearchResult.setShardIndex(shardId);
    fetchSearchResult.setSearchShardTarget(searchShardTarget);
    SearchHit[] searchHitArray = new SearchHit[] {
        new SearchHit(4, "2", Map.of(), Map.of()),
        new SearchHit(4, "2", Map.of(), Map.of()),
        new SearchHit(0, "10", Map.of(), Map.of()),
        new SearchHit(2, "1", Map.of(), Map.of()),
        new SearchHit(4, "2", Map.of(), Map.of()),
        new SearchHit(10, "3", Map.of(), Map.of()),
        new SearchHit(4, "2", Map.of(), Map.of()) };
    SearchHits searchHits = new SearchHits(searchHitArray, new TotalHits(7, TotalHits.Relation.EQUAL_TO), 10);
    fetchSearchResult.hits(searchHits);

    QueryFetchSearchResult queryFetchSearchResult = new QueryFetchSearchResult(querySearchResult, fetchSearchResult);
    queryFetchSearchResult.setShardIndex(shardId);

    queryPhaseResultConsumer.consumeResult(queryFetchSearchResult, partialReduceLatch::countDown);

    SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class);
    normalizationProcessor.process(queryPhaseResultConsumer, searchPhaseContext);

    List<QuerySearchResult> querySearchResults = queryPhaseResultConsumer.getAtomicArray()
        .asList()
        .stream()
        .map(result -> result == null ? null : result.queryResult())
        .collect(Collectors.toList());

    TestUtils.assertQueryResultScores(querySearchResults);
    // sizes match, so the processor must not be skipped
    verify(normalizationProcessorWorkflow).execute(any(), any(), any(), any());
}

/**
 * Verifies that the normalization workflow is never executed for a combined query-and-fetch
 * result when the number of fetched hits differs from the number of score docs in the query result.
 *
 * Setup: a single shard (index 0) whose query result carries seven score docs (the elements
 * produced by the start/stop and delimiter helpers plus four regular docs) while the fetch
 * result carries only five search hits. Both are wrapped into one {@code QueryFetchSearchResult}
 * and handed to the processor through a {@code QueryPhaseResultConsumer}.
 */
public void testResultTypes_whenQueryAndFetchPresentButSizeDifferent_thenSkipNormalization() {
    NormalizationProcessorWorkflow normalizationProcessorWorkflow = spy(
        new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner())
    );
    NormalizationProcessor normalizationProcessor = new NormalizationProcessor(
        PROCESSOR_TAG,
        DESCRIPTION,
        new ScoreNormalizationFactory().createNormalization(NORMALIZATION_METHOD),
        new ScoreCombinationFactory().createCombination(COMBINATION_METHOD),
        normalizationProcessorWorkflow
    );

    SearchRequest searchRequest = new SearchRequest(INDEX_NAME);
    searchRequest.setBatchedReduceSize(4);
    AtomicReference<Exception> onPartialMergeFailure = new AtomicReference<>();
    QueryPhaseResultConsumer queryPhaseResultConsumer = new QueryPhaseResultConsumer(
        searchRequest,
        executor,
        new NoopCircuitBreaker(CircuitBreaker.REQUEST),
        searchPhaseController,
        SearchProgressListener.NOOP,
        writableRegistry(),
        10,
        e -> onPartialMergeFailure.accumulateAndGet(e, (prev, curr) -> {
            // The reference starts out empty, so the first failure has no predecessor;
            // Throwable#addSuppressed throws NullPointerException when given null.
            if (prev != null) {
                curr.addSuppressed(prev);
            }
            return curr;
        })
    );
    CountDownLatch partialReduceLatch = new CountDownLatch(5);
    int shardId = 0;
    SearchShardTarget searchShardTarget = new SearchShardTarget(
        "node",
        new ShardId("index", "uuid", shardId),
        null,
        OriginalIndices.NONE
    );

    // query side: seven score docs in total
    QuerySearchResult querySearchResult = new QuerySearchResult();
    TopDocs topDocs = new TopDocs(
        new TotalHits(4, TotalHits.Relation.EQUAL_TO),

        new ScoreDoc[] {
            createStartStopElementForHybridSearchResults(4),
            createDelimiterElementForHybridSearchResults(4),
            new ScoreDoc(0, 0.5f),
            new ScoreDoc(2, 0.3f),
            new ScoreDoc(4, 0.25f),
            new ScoreDoc(10, 0.2f),
            createStartStopElementForHybridSearchResults(4) }

    );
    querySearchResult.topDocs(new TopDocsAndMaxScore(topDocs, 0.5f), new DocValueFormat[0]);
    querySearchResult.setSearchShardTarget(searchShardTarget);
    querySearchResult.setShardIndex(shardId);

    // fetch side: only five hits, deliberately fewer than the seven score docs above
    FetchSearchResult fetchSearchResult = new FetchSearchResult();
    fetchSearchResult.setShardIndex(shardId);
    fetchSearchResult.setSearchShardTarget(searchShardTarget);
    SearchHit[] searchHitArray = new SearchHit[] {
        new SearchHit(0, "10", Map.of(), Map.of()),
        new SearchHit(2, "1", Map.of(), Map.of()),
        new SearchHit(4, "2", Map.of(), Map.of()),
        new SearchHit(10, "3", Map.of(), Map.of()),
        new SearchHit(0, "10", Map.of(), Map.of()), };
    SearchHits searchHits = new SearchHits(searchHitArray, new TotalHits(5, TotalHits.Relation.EQUAL_TO), 10);
    fetchSearchResult.hits(searchHits);

    QueryFetchSearchResult queryFetchSearchResult = new QueryFetchSearchResult(querySearchResult, fetchSearchResult);
    queryFetchSearchResult.setShardIndex(shardId);

    queryPhaseResultConsumer.consumeResult(queryFetchSearchResult, partialReduceLatch::countDown);

    SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class);
    normalizationProcessor.process(queryPhaseResultConsumer, searchPhaseContext);

    List<QuerySearchResult> querySearchResults = queryPhaseResultConsumer.getAtomicArray()
        .asList()
        .stream()
        .map(result -> result == null ? null : result.queryResult())
        .collect(Collectors.toList());

    assertNotNull(querySearchResults);
    // sizes differ, so the processor must have been skipped
    verify(normalizationProcessorWorkflow, never()).execute(any(), any(), any(), any());
}
}

0 comments on commit e404775

Please sign in to comment.