Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed exception in Hybrid Query for one shard and multiple node #396

Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
Support sparse semantic retrieval by introducing `sparse_encoding` ingest processor and query builder ([#333](https://github.com/opensearch-project/neural-search/pull/333))
### Enhancements
### Bug Fixes
Fixed exception in Hybrid Query for one shard and multiple node ([#396](https://github.com/opensearch-project/neural-search/pull/396))
### Infrastructure
### Documentation
### Maintenance
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
package org.opensearch.neuralsearch.processor;

import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Function;
import java.util.stream.Collectors;

import lombok.AllArgsConstructor;
Expand Down Expand Up @@ -52,6 +52,9 @@
final ScoreNormalizationTechnique normalizationTechnique,
final ScoreCombinationTechnique combinationTechnique
) {
// save original state
List<Integer> unprocessedDocIds = unprocessedDocIds(querySearchResults);

// pre-process data
log.debug("Pre-process query results");
List<CompoundTopDocs> queryTopDocs = getQueryTopDocs(querySearchResults);
Expand All @@ -67,7 +70,7 @@
// post-process data
log.debug("Post-process query results after score normalization and combination");
updateOriginalQueryResults(querySearchResults, queryTopDocs);
updateOriginalFetchResults(querySearchResults, fetchSearchResultOptional);
updateOriginalFetchResults(querySearchResults, fetchSearchResultOptional, unprocessedDocIds);
}

/**
Expand Down Expand Up @@ -123,7 +126,8 @@
*/
private void updateOriginalFetchResults(
final List<QuerySearchResult> querySearchResults,
final Optional<FetchSearchResult> fetchSearchResultOptional
final Optional<FetchSearchResult> fetchSearchResultOptional,
final List<Integer> docIds
) {
if (fetchSearchResultOptional.isEmpty()) {
return;
Expand All @@ -135,14 +139,17 @@
// 3. update original scores to normalized and combined values
// 4. order scores based on normalized and combined values
FetchSearchResult fetchSearchResult = fetchSearchResultOptional.get();
SearchHits searchHits = fetchSearchResult.hits();
SearchHit[] searchHitArray = getSearchHits(docIds, fetchSearchResult);

// create map of docId to index of search hits. This solves (2), duplicates are from
// delimiter and start/stop elements, they all have same valid doc_id. For this map
// we use doc_id as a key, and all those special elements are collapsed into a single
// key-value pair.
Map<Integer, SearchHit> docIdToSearchHit = Arrays.stream(searchHits.getHits())
.collect(Collectors.toMap(SearchHit::docId, Function.identity(), (a1, a2) -> a1));
Map<Integer, SearchHit> docIdToSearchHit = new HashMap<>();
for (int i = 0; i < searchHitArray.length; i++) {
int originalDocId = docIds.get(i);
docIdToSearchHit.put(originalDocId, searchHitArray[i]);
}

QuerySearchResult querySearchResult = querySearchResults.get(0);
TopDocs topDocs = querySearchResult.topDocs().topDocs;
Expand All @@ -161,4 +168,23 @@
);
fetchSearchResult.hits(updatedSearchHits);
}

private SearchHit[] getSearchHits(List<Integer> docIds, FetchSearchResult fetchSearchResult) {
martin-gaievski marked this conversation as resolved.
Show resolved Hide resolved
SearchHits searchHits = fetchSearchResult.hits();
SearchHit[] searchHitArray = searchHits.getHits();
// validate the both collections are of the same size
if (Objects.isNull(searchHitArray) || searchHitArray.length != docIds.size()) {
throw new IllegalStateException("Score normalization processor cannot produce final query result");
}
return searchHitArray;
}

private List<Integer> unprocessedDocIds(List<QuerySearchResult> querySearchResults) {
martin-gaievski marked this conversation as resolved.
Show resolved Hide resolved
List<Integer> docIds = querySearchResults.isEmpty()
? List.of()

Check warning on line 184 in src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java#L184

Added line #L184 was not covered by tests
: Arrays.stream(querySearchResults.get(0).topDocs().topDocs.scoreDocs)
.map(scoreDoc -> scoreDoc.doc)
.collect(Collectors.toList());
return docIds;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -179,4 +179,117 @@ public void testFetchResults_whenOneShardAndQueryAndFetchResultsPresent_thenDoNo
TestUtils.assertQueryResultScores(querySearchResults);
TestUtils.assertFetchResultScores(fetchSearchResult, 4);
}

public void testFetchResults_whenOneShardAndMultipleNodes_thenDoNormalizationCombination() {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this test case would fail with old code

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adding a negative case.

NormalizationProcessorWorkflow normalizationProcessorWorkflow = spy(
new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner())
);

List<QuerySearchResult> querySearchResults = new ArrayList<>();
FetchSearchResult fetchSearchResult = new FetchSearchResult();
int shardId = 0;
SearchShardTarget searchShardTarget = new SearchShardTarget(
"node",
new ShardId("index", "uuid", shardId),
null,
OriginalIndices.NONE
);
QuerySearchResult querySearchResult = new QuerySearchResult();
querySearchResult.topDocs(
new TopDocsAndMaxScore(
new TopDocs(
new TotalHits(4, TotalHits.Relation.EQUAL_TO),
new ScoreDoc[] {
createStartStopElementForHybridSearchResults(0),
createDelimiterElementForHybridSearchResults(0),
new ScoreDoc(0, 0.5f),
new ScoreDoc(2, 0.3f),
new ScoreDoc(4, 0.25f),
new ScoreDoc(10, 0.2f),
createStartStopElementForHybridSearchResults(0) }
),
0.5f
),
new DocValueFormat[0]
);
querySearchResult.setSearchShardTarget(searchShardTarget);
querySearchResult.setShardIndex(shardId);
querySearchResults.add(querySearchResult);
SearchHit[] searchHitArray = new SearchHit[] {
new SearchHit(-1, "10", Map.of(), Map.of()),
new SearchHit(-1, "10", Map.of(), Map.of()),
new SearchHit(-1, "10", Map.of(), Map.of()),
new SearchHit(-1, "1", Map.of(), Map.of()),
new SearchHit(-1, "2", Map.of(), Map.of()),
new SearchHit(-1, "3", Map.of(), Map.of()),
new SearchHit(-1, "10", Map.of(), Map.of()), };
SearchHits searchHits = new SearchHits(searchHitArray, new TotalHits(7, TotalHits.Relation.EQUAL_TO), 10);
fetchSearchResult.hits(searchHits);

normalizationProcessorWorkflow.execute(
querySearchResults,
Optional.of(fetchSearchResult),
ScoreNormalizationFactory.DEFAULT_METHOD,
ScoreCombinationFactory.DEFAULT_METHOD
);

TestUtils.assertQueryResultScores(querySearchResults);
TestUtils.assertFetchResultScores(fetchSearchResult, 4);
}

public void testFetchResults_whenOneShardAndMultipleNodesAndMismatchResults_thenFail() {
NormalizationProcessorWorkflow normalizationProcessorWorkflow = spy(
new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner())
);

List<QuerySearchResult> querySearchResults = new ArrayList<>();
FetchSearchResult fetchSearchResult = new FetchSearchResult();
int shardId = 0;
SearchShardTarget searchShardTarget = new SearchShardTarget(
"node",
new ShardId("index", "uuid", shardId),
null,
OriginalIndices.NONE
);
QuerySearchResult querySearchResult = new QuerySearchResult();
querySearchResult.topDocs(
new TopDocsAndMaxScore(
new TopDocs(
new TotalHits(4, TotalHits.Relation.EQUAL_TO),
new ScoreDoc[] {
createStartStopElementForHybridSearchResults(0),
createDelimiterElementForHybridSearchResults(0),
new ScoreDoc(0, 0.5f),
new ScoreDoc(2, 0.3f),
new ScoreDoc(4, 0.25f),
new ScoreDoc(10, 0.2f),
createStartStopElementForHybridSearchResults(0) }
),
0.5f
),
new DocValueFormat[0]
);
querySearchResult.setSearchShardTarget(searchShardTarget);
querySearchResult.setShardIndex(shardId);
querySearchResults.add(querySearchResult);
SearchHit[] searchHitArray = new SearchHit[] {
new SearchHit(-1, "10", Map.of(), Map.of()),
new SearchHit(-1, "10", Map.of(), Map.of()),
new SearchHit(-1, "10", Map.of(), Map.of()),
new SearchHit(-1, "1", Map.of(), Map.of()),
new SearchHit(-1, "2", Map.of(), Map.of()),
new SearchHit(-1, "3", Map.of(), Map.of()) };
SearchHits searchHits = new SearchHits(searchHitArray, new TotalHits(7, TotalHits.Relation.EQUAL_TO), 10);
fetchSearchResult.hits(searchHits);

expectThrows(
IllegalStateException.class,
() -> normalizationProcessorWorkflow.execute(
querySearchResults,
Optional.of(fetchSearchResult),
ScoreNormalizationFactory.DEFAULT_METHOD,
ScoreCombinationFactory.DEFAULT_METHOD
)
);
}
}
Loading