Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.11] Fixed exception in Hybrid Query for one shard and multiple node #400

Merged
merged 1 commit into from
Oct 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ Added Multimodal semantic search feature ([#359](https://github.com/opensearch-p
### Enhancements
Add `max_token_score` parameter to improve the execution efficiency for `neural_sparse` query clause ([#348](https://github.com/opensearch-project/neural-search/pull/348))
### Bug Fixes
Fixed exception in Hybrid Query for one shard and multiple node ([#396](https://github.com/opensearch-project/neural-search/pull/396))
### Infrastructure
### Documentation
### Maintenance
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
package org.opensearch.neuralsearch.processor;

import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Function;
import java.util.stream.Collectors;

import lombok.AllArgsConstructor;
Expand Down Expand Up @@ -52,6 +52,9 @@ public void execute(
final ScoreNormalizationTechnique normalizationTechnique,
final ScoreCombinationTechnique combinationTechnique
) {
// save original state
List<Integer> unprocessedDocIds = unprocessedDocIds(querySearchResults);

// pre-process data
log.debug("Pre-process query results");
List<CompoundTopDocs> queryTopDocs = getQueryTopDocs(querySearchResults);
Expand All @@ -67,7 +70,7 @@ public void execute(
// post-process data
log.debug("Post-process query results after score normalization and combination");
updateOriginalQueryResults(querySearchResults, queryTopDocs);
updateOriginalFetchResults(querySearchResults, fetchSearchResultOptional);
updateOriginalFetchResults(querySearchResults, fetchSearchResultOptional, unprocessedDocIds);
}

/**
Expand Down Expand Up @@ -123,7 +126,8 @@ private void updateOriginalQueryResults(final List<QuerySearchResult> querySearc
*/
private void updateOriginalFetchResults(
final List<QuerySearchResult> querySearchResults,
final Optional<FetchSearchResult> fetchSearchResultOptional
final Optional<FetchSearchResult> fetchSearchResultOptional,
final List<Integer> docIds
) {
if (fetchSearchResultOptional.isEmpty()) {
return;
Expand All @@ -135,14 +139,17 @@ private void updateOriginalFetchResults(
// 3. update original scores to normalized and combined values
// 4. order scores based on normalized and combined values
FetchSearchResult fetchSearchResult = fetchSearchResultOptional.get();
SearchHits searchHits = fetchSearchResult.hits();
SearchHit[] searchHitArray = getSearchHits(docIds, fetchSearchResult);

// create map of docId to index of search hits. This solves (2), duplicates are from
// delimiter and start/stop elements, they all have same valid doc_id. For this map
// we use doc_id as a key, and all those special elements are collapsed into a single
// key-value pair.
Map<Integer, SearchHit> docIdToSearchHit = Arrays.stream(searchHits.getHits())
.collect(Collectors.toMap(SearchHit::docId, Function.identity(), (a1, a2) -> a1));
Map<Integer, SearchHit> docIdToSearchHit = new HashMap<>();
for (int i = 0; i < searchHitArray.length; i++) {
int originalDocId = docIds.get(i);
docIdToSearchHit.put(originalDocId, searchHitArray[i]);
}

QuerySearchResult querySearchResult = querySearchResults.get(0);
TopDocs topDocs = querySearchResult.topDocs().topDocs;
Expand All @@ -161,4 +168,23 @@ private void updateOriginalFetchResults(
);
fetchSearchResult.hits(updatedSearchHits);
}

private SearchHit[] getSearchHits(final List<Integer> docIds, final FetchSearchResult fetchSearchResult) {
SearchHits searchHits = fetchSearchResult.hits();
SearchHit[] searchHitArray = searchHits.getHits();
// validate the both collections are of the same size
if (Objects.isNull(searchHitArray) || searchHitArray.length != docIds.size()) {
throw new IllegalStateException("Score normalization processor cannot produce final query result");
}
return searchHitArray;
}

private List<Integer> unprocessedDocIds(final List<QuerySearchResult> querySearchResults) {
List<Integer> docIds = querySearchResults.isEmpty()
? List.of()
: Arrays.stream(querySearchResults.get(0).topDocs().topDocs.scoreDocs)
.map(scoreDoc -> scoreDoc.doc)
.collect(Collectors.toList());
return docIds;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -179,4 +179,117 @@ public void testFetchResults_whenOneShardAndQueryAndFetchResultsPresent_thenDoNo
TestUtils.assertQueryResultScores(querySearchResults);
TestUtils.assertFetchResultScores(fetchSearchResult, 4);
}

public void testFetchResults_whenOneShardAndMultipleNodes_thenDoNormalizationCombination() {
NormalizationProcessorWorkflow normalizationProcessorWorkflow = spy(
new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner())
);

List<QuerySearchResult> querySearchResults = new ArrayList<>();
FetchSearchResult fetchSearchResult = new FetchSearchResult();
int shardId = 0;
SearchShardTarget searchShardTarget = new SearchShardTarget(
"node",
new ShardId("index", "uuid", shardId),
null,
OriginalIndices.NONE
);
QuerySearchResult querySearchResult = new QuerySearchResult();
querySearchResult.topDocs(
new TopDocsAndMaxScore(
new TopDocs(
new TotalHits(4, TotalHits.Relation.EQUAL_TO),
new ScoreDoc[] {
createStartStopElementForHybridSearchResults(0),
createDelimiterElementForHybridSearchResults(0),
new ScoreDoc(0, 0.5f),
new ScoreDoc(2, 0.3f),
new ScoreDoc(4, 0.25f),
new ScoreDoc(10, 0.2f),
createStartStopElementForHybridSearchResults(0) }
),
0.5f
),
new DocValueFormat[0]
);
querySearchResult.setSearchShardTarget(searchShardTarget);
querySearchResult.setShardIndex(shardId);
querySearchResults.add(querySearchResult);
SearchHit[] searchHitArray = new SearchHit[] {
new SearchHit(-1, "10", Map.of(), Map.of()),
new SearchHit(-1, "10", Map.of(), Map.of()),
new SearchHit(-1, "10", Map.of(), Map.of()),
new SearchHit(-1, "1", Map.of(), Map.of()),
new SearchHit(-1, "2", Map.of(), Map.of()),
new SearchHit(-1, "3", Map.of(), Map.of()),
new SearchHit(-1, "10", Map.of(), Map.of()), };
SearchHits searchHits = new SearchHits(searchHitArray, new TotalHits(7, TotalHits.Relation.EQUAL_TO), 10);
fetchSearchResult.hits(searchHits);

normalizationProcessorWorkflow.execute(
querySearchResults,
Optional.of(fetchSearchResult),
ScoreNormalizationFactory.DEFAULT_METHOD,
ScoreCombinationFactory.DEFAULT_METHOD
);

TestUtils.assertQueryResultScores(querySearchResults);
TestUtils.assertFetchResultScores(fetchSearchResult, 4);
}

public void testFetchResults_whenOneShardAndMultipleNodesAndMismatchResults_thenFail() {
NormalizationProcessorWorkflow normalizationProcessorWorkflow = spy(
new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner())
);

List<QuerySearchResult> querySearchResults = new ArrayList<>();
FetchSearchResult fetchSearchResult = new FetchSearchResult();
int shardId = 0;
SearchShardTarget searchShardTarget = new SearchShardTarget(
"node",
new ShardId("index", "uuid", shardId),
null,
OriginalIndices.NONE
);
QuerySearchResult querySearchResult = new QuerySearchResult();
querySearchResult.topDocs(
new TopDocsAndMaxScore(
new TopDocs(
new TotalHits(4, TotalHits.Relation.EQUAL_TO),
new ScoreDoc[] {
createStartStopElementForHybridSearchResults(0),
createDelimiterElementForHybridSearchResults(0),
new ScoreDoc(0, 0.5f),
new ScoreDoc(2, 0.3f),
new ScoreDoc(4, 0.25f),
new ScoreDoc(10, 0.2f),
createStartStopElementForHybridSearchResults(0) }
),
0.5f
),
new DocValueFormat[0]
);
querySearchResult.setSearchShardTarget(searchShardTarget);
querySearchResult.setShardIndex(shardId);
querySearchResults.add(querySearchResult);
SearchHit[] searchHitArray = new SearchHit[] {
new SearchHit(-1, "10", Map.of(), Map.of()),
new SearchHit(-1, "10", Map.of(), Map.of()),
new SearchHit(-1, "10", Map.of(), Map.of()),
new SearchHit(-1, "1", Map.of(), Map.of()),
new SearchHit(-1, "2", Map.of(), Map.of()),
new SearchHit(-1, "3", Map.of(), Map.of()) };
SearchHits searchHits = new SearchHits(searchHitArray, new TotalHits(7, TotalHits.Relation.EQUAL_TO), 10);
fetchSearchResult.hits(searchHits);

expectThrows(
IllegalStateException.class,
() -> normalizationProcessorWorkflow.execute(
querySearchResults,
Optional.of(fetchSearchResult),
ScoreNormalizationFactory.DEFAULT_METHOD,
ScoreCombinationFactory.DEFAULT_METHOD
)
);
}
}