From c462f4d59221f0d76031e1943d9f8d1d70897124 Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Thu, 28 Dec 2023 12:15:44 -0800 Subject: [PATCH] Address Navneets comments Signed-off-by: Martin Gaievski --- CHANGELOG.md | 2 +- .../processor/NormalizationProcessor.java | 15 +++- .../NormalizationProcessorWorkflow.java | 18 +---- .../neuralsearch/query/HybridQueryScorer.java | 17 ++++- .../NormalizationProcessorWorkflowTests.java | 15 ++-- .../neuralsearch/query/HybridQueryIT.java | 69 +++++++++++++++++++ 6 files changed, 107 insertions(+), 29 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6522d54f2..c9896ff13 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,7 +20,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Features ### Enhancements ### Bug Fixes -- Multiple identical subqueries in Hybrid query ([#524](https://github.com/opensearch-project/neural-search/pull/524)) +- Fixing multiple issues reported in #497 ([#524](https://github.com/opensearch-project/neural-search/pull/524)) ### Infrastructure ### Documentation ### Maintenance diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java index 6126a3c56..657b5c30c 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java @@ -9,6 +9,7 @@ import java.util.Arrays; import java.util.List; +import java.util.Locale; import java.util.Objects; import java.util.Optional; import java.util.stream.Collectors; @@ -155,7 +156,19 @@ private boolean shouldSkipProcessorDueToIncompatibleQueryAndFetchResults( SearchHits searchHits = fetchSearchResultOptional.get().hits(); SearchHit[] searchHitArray = searchHits.getHits(); // validate the both collections are of the same size - if (Objects.isNull(searchHitArray) || searchHitArray.length != docIds.size()) { + if (Objects.isNull(searchHitArray)) { + log.info("array of search hits in fetch phase results is null"); + return true; + } + if (searchHitArray.length != docIds.size()) { + log.info( + String.format( + Locale.ROOT, + "number of documents in fetch results [%d] and query results [%d] is different", + searchHitArray.length, + docIds.size() + ) + ); return true; } return false; diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java index 55ec63631..5929370be 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java @@ -139,7 +139,7 @@ private void updateOriginalFetchResults( // 3. update original scores to normalized and combined values // 4. order scores based on normalized and combined values FetchSearchResult fetchSearchResult = fetchSearchResultOptional.get(); - SearchHit[] searchHitArray = getSearchHits(docIds, fetchSearchResult); + SearchHit[] searchHitArray = getSearchHits(fetchSearchResult); // create map of docId to index of search hits. This solves (2), duplicates are from // delimiter and start/stop elements, they all have same valid doc_id. For this map @@ -169,21 +169,9 @@ private void updateOriginalFetchResults( fetchSearchResult.hits(updatedSearchHits); } - private SearchHit[] getSearchHits(final List docIds, final FetchSearchResult fetchSearchResult) { + private SearchHit[] getSearchHits(final FetchSearchResult fetchSearchResult) { SearchHits searchHits = fetchSearchResult.hits(); - SearchHit[] searchHitArray = searchHits.getHits(); - // validate the both collections are of the same size - if (Objects.isNull(searchHitArray)) { - throw new IllegalStateException( - "Score normalization processor cannot produce final query result, for one shard case fetch does not have any results" - ); - } - if (searchHitArray.length != docIds.size()) { - throw new IllegalStateException( - "Score normalization processor cannot produce final query result, for one shard case number of fetched documents does not match number of search hits" - ); - } - return searchHitArray; + return searchHits.getHits(); } private List unprocessedDocIds(final List querySearchResults) { diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java index ca2f06fc3..57ad4451f 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java +++ b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java @@ -5,6 +5,8 @@ package org.opensearch.neuralsearch.query; +import static java.util.Locale.ROOT; + import java.io.IOException; import java.util.ArrayList; import java.util.Collections; @@ -118,12 +120,21 @@ public float[] hybridScores() throws IOException { } Query query = scorer.getWeight().getQuery(); List indexes = queryToIndex.get(query); - // we need to find the index of first sub-query that hasn't been updated yet + // we need to find the index of first sub-query that hasn't been set yet. Such score will have initial value of "0.0" int index = indexes.stream() .mapToInt(idx -> idx) - .filter(index1 -> Float.compare(scores[index1], 0.0f) == 0) + .filter(idx -> Float.compare(scores[idx], 0.0f) == 0) .findFirst() - .orElseThrow(() -> new IllegalStateException("cannot collect score for subquery")); + .orElseThrow( + () -> new IllegalStateException( + String.format( + ROOT, + "cannot set score for one of hybrid search subquery [%s] and document [%d]", + query.toString(), + scorer.docID() + ) + ) + ); scores[index] = scorer.score(); } return scores; diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java index 95c2ba0c2..e52e24672 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java @@ -237,7 +237,7 @@ public void testFetchResults_whenOneShardAndMultipleNodes_thenDoNormalizationCom TestUtils.assertFetchResultScores(fetchSearchResult, 4); } - public void testFetchResults_whenOneShardAndMultipleNodesAndMismatchResults_thenFail() { + public void testFetchResults_whenOneShardAndMultipleNodesAndMismatchResults_thenSuccess() { NormalizationProcessorWorkflow normalizationProcessorWorkflow = spy( new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()) ); @@ -282,14 +282,11 @@ public void testFetchResults_whenOneShardAndMultipleNodesAndMismatchResults_then SearchHits searchHits = new SearchHits(searchHitArray, new TotalHits(7, TotalHits.Relation.EQUAL_TO), 10); fetchSearchResult.hits(searchHits); - expectThrows( - IllegalStateException.class, - () -> normalizationProcessorWorkflow.execute( - querySearchResults, - Optional.of(fetchSearchResult), - ScoreNormalizationFactory.DEFAULT_METHOD, - ScoreCombinationFactory.DEFAULT_METHOD - ) + normalizationProcessorWorkflow.execute( + querySearchResults, + Optional.of(fetchSearchResult), + ScoreNormalizationFactory.DEFAULT_METHOD, + ScoreCombinationFactory.DEFAULT_METHOD ); } } diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java index 312cb8b3a..36613ef1b 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java @@ -174,6 +174,75 @@ public void testComplexQuery_whenMultipleSubqueries_thenSuccessful() { assertEquals(RELATION_EQUAL_TO, total.get("relation")); } + /** + * Tests complex query with multiple nested sub-queries, where soem sub-queries are same + * { + * "query": { + * "hybrid": { + * "queries": [ + * { + * "term": { + * "text": "word1" + * } + * }, + * { + * "term": { + * "text": "word2" + * } + * }, + * { + * "term": { + * "text": "word3" + * } + * } + * ] + * } + * } + * } + */ + @SneakyThrows + public void testComplexQuery_whenMultipleIdenticalSubQueries_thenSuccessful() { + initializeIndexIfNotExist(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME); + + TermQueryBuilder termQueryBuilder1 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); + TermQueryBuilder termQueryBuilder2 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT4); + TermQueryBuilder termQueryBuilder3 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); + + HybridQueryBuilder hybridQueryBuilderThreeTerms = new HybridQueryBuilder(); + hybridQueryBuilderThreeTerms.add(termQueryBuilder1); + hybridQueryBuilderThreeTerms.add(termQueryBuilder2); + hybridQueryBuilderThreeTerms.add(termQueryBuilder3); + + Map searchResponseAsMap1 = search( + TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME, + hybridQueryBuilderThreeTerms, + null, + 10, + Map.of("search_pipeline", SEARCH_PIPELINE) + ); + + assertEquals(2, getHitCount(searchResponseAsMap1)); + + List> hits1NestedList = getNestedHits(searchResponseAsMap1); + List ids = new ArrayList<>(); + List scores = new ArrayList<>(); + for (Map oneHit : hits1NestedList) { + ids.add((String) oneHit.get("_id")); + scores.add((Double) oneHit.get("_score")); + } + + // verify that scores are in desc order + assertTrue(IntStream.range(0, scores.size() - 1).noneMatch(idx -> scores.get(idx) < scores.get(idx + 1))); + // verify that all ids are unique + assertEquals(Set.copyOf(ids).size(), ids.size()); + + Map total = getTotalHits(searchResponseAsMap1); + assertNotNull(total.get("value")); + assertEquals(2, total.get("value")); + assertNotNull(total.get("relation")); + assertEquals(RELATION_EQUAL_TO, total.get("relation")); + } + @SneakyThrows public void testNoMatchResults_whenOnlyTermSubQueryWithoutMatch_thenEmptyResult() { initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME);