diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java index 657b5c30c..0d6742dbe 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java @@ -7,9 +7,7 @@ import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.isHybridQueryStartStopElement; -import java.util.Arrays; import java.util.List; -import java.util.Locale; import java.util.Objects; import java.util.Optional; import java.util.stream.Collectors; @@ -20,8 +18,6 @@ import org.opensearch.action.search.SearchPhaseResults; import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique; import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique; -import org.opensearch.search.SearchHit; -import org.opensearch.search.SearchHits; import org.opensearch.search.SearchPhaseResult; import org.opensearch.search.fetch.FetchSearchResult; import org.opensearch.search.internal.SearchContext; @@ -102,16 +98,7 @@ private boolean shouldSkipProcessor(SearchPha } QueryPhaseResultConsumer queryPhaseResultConsumer = (QueryPhaseResultConsumer) searchPhaseResult; - if (queryPhaseResultConsumer.getAtomicArray().asList().stream().filter(Objects::nonNull).noneMatch(this::isHybridQuery)) { - return true; - } - List querySearchResults = getQueryPhaseSearchResults(searchPhaseResult); - Optional fetchSearchResult = getFetchSearchResults(searchPhaseResult); - if (shouldSkipProcessorDueToIncompatibleQueryAndFetchResults(querySearchResults, fetchSearchResult)) { - log.debug("Query and fetch results do not match, normalization processor is skipped"); - return true; - } - return false; + return queryPhaseResultConsumer.getAtomicArray().asList().stream().filter(Objects::nonNull).noneMatch(this::isHybridQuery); } /** @@ -144,41 +131,4 @@ private Optional getFetchS Optional optionalFirstSearchPhaseResult = searchPhaseResults.getAtomicArray().asList().stream().findFirst(); return optionalFirstSearchPhaseResult.map(SearchPhaseResult::fetchResult); } - - private boolean shouldSkipProcessorDueToIncompatibleQueryAndFetchResults( - final List querySearchResults, - final Optional fetchSearchResultOptional - ) { - if (fetchSearchResultOptional.isEmpty()) { - return false; - } - final List docIds = unprocessedDocIds(querySearchResults); - SearchHits searchHits = fetchSearchResultOptional.get().hits(); - SearchHit[] searchHitArray = searchHits.getHits(); - // validate the both collections are of the same size - if (Objects.isNull(searchHitArray)) { - log.info("array of search hits in fetch phase results is null"); - return true; - } - if (searchHitArray.length != docIds.size()) { - log.info( - String.format( - Locale.ROOT, - "number of documents in fetch results [%d] and query results [%d] is different", - searchHitArray.length, - docIds.size() - ) - ); - return true; - } - return false; - } - - private List unprocessedDocIds(final List querySearchResults) { - return querySearchResults.isEmpty() - ? List.of() - : Arrays.stream(querySearchResults.get(0).topDocs().topDocs.scoreDocs) - .map(scoreDoc -> scoreDoc.doc) - .collect(Collectors.toList()); - } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java index 5929370be..71daeac35 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java @@ -139,7 +139,7 @@ private void updateOriginalFetchResults( // 3. update original scores to normalized and combined values // 4. order scores based on normalized and combined values FetchSearchResult fetchSearchResult = fetchSearchResultOptional.get(); - SearchHit[] searchHitArray = getSearchHits(fetchSearchResult); + SearchHit[] searchHitArray = getSearchHits(docIds, fetchSearchResult); // create map of docId to index of search hits. This solves (2), duplicates are from // delimiter and start/stop elements, they all have same valid doc_id. For this map @@ -169,9 +169,21 @@ private void updateOriginalFetchResults( fetchSearchResult.hits(updatedSearchHits); } - private SearchHit[] getSearchHits(final FetchSearchResult fetchSearchResult) { + private SearchHit[] getSearchHits(final List docIds, final FetchSearchResult fetchSearchResult) { SearchHits searchHits = fetchSearchResult.hits(); - return searchHits.getHits(); + SearchHit[] searchHitArray = searchHits.getHits(); + // validate the both collections are of the same size + if (Objects.isNull(searchHitArray)) { + throw new IllegalStateException( + "score normalization processor cannot produce final query result, fetch query phase returns empty results" + ); + } + if (searchHitArray.length != docIds.size()) { + throw new IllegalStateException( + "score normalization processor cannot produce final query result, the number of documents returned by fetch and query phases does not match" + ); + } + return searchHitArray; } private List unprocessedDocIds(final List querySearchResults) { diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java index 46ee122d1..26d9fc808 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java @@ -5,6 +5,7 @@ package org.opensearch.neuralsearch.processor; +import static org.hamcrest.Matchers.startsWith; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; @@ -417,7 +418,7 @@ public void testResultTypes_whenQueryAndFetchPresentAndSizeSame_thenCallNormaliz verify(normalizationProcessorWorkflow).execute(any(), any(), any(), any()); } - public void testResultTypes_whenQueryAndFetchPresentButSizeDifferent_thenSkipNormalization() { + public void testResultTypes_whenQueryAndFetchPresentButSizeDifferent_thenFail() { NormalizationProcessorWorkflow normalizationProcessorWorkflow = spy( new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()) ); @@ -489,15 +490,13 @@ public void testResultTypes_whenQueryAndFetchPresentButSizeDifferent_thenSkipNor queryPhaseResultConsumer.consumeResult(queryFetchSearchResult, partialReduceLatch::countDown); SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class); - normalizationProcessor.process(queryPhaseResultConsumer, searchPhaseContext); - - List querySearchResults = queryPhaseResultConsumer.getAtomicArray() - .asList() - .stream() - .map(result -> result == null ? null : result.queryResult()) - .collect(Collectors.toList()); - - assertNotNull(querySearchResults); - verify(normalizationProcessorWorkflow, never()).execute(any(), any(), any(), any()); + IllegalStateException exception = expectThrows( + IllegalStateException.class, + () -> normalizationProcessor.process(queryPhaseResultConsumer, searchPhaseContext) + ); + org.hamcrest.MatcherAssert.assertThat( + exception.getMessage(), + startsWith("score normalization processor cannot produce final query result") + ); } } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java index a8f1d8eb7..95c2ba0c2 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java @@ -237,7 +237,7 @@ public void testFetchResults_whenOneShardAndMultipleNodes_thenDoNormalizationCom TestUtils.assertFetchResultScores(fetchSearchResult, 4); } - public void testFetchResults_whenOneShardAndMultipleNodesAndMismatchResults_thenSuccess() { + public void testFetchResults_whenOneShardAndMultipleNodesAndMismatchResults_thenFail() { NormalizationProcessorWorkflow normalizationProcessorWorkflow = spy( new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()) ); @@ -282,12 +282,14 @@ public void testFetchResults_whenOneShardAndMultipleNodesAndMismatchResults_then SearchHits searchHits = new SearchHits(searchHitArray, new TotalHits(7, TotalHits.Relation.EQUAL_TO), 10); fetchSearchResult.hits(searchHits); - normalizationProcessorWorkflow.execute( - querySearchResults, - Optional.of(fetchSearchResult), - ScoreNormalizationFactory.DEFAULT_METHOD, - ScoreCombinationFactory.DEFAULT_METHOD + expectThrows( + IllegalStateException.class, + () -> normalizationProcessorWorkflow.execute( + querySearchResults, + Optional.of(fetchSearchResult), + ScoreNormalizationFactory.DEFAULT_METHOD, + ScoreCombinationFactory.DEFAULT_METHOD + ) ); - TestUtils.assertQueryResultScores(querySearchResults); } }